Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-11-11 21:29:50 +0300
committerGitHub <noreply@github.com>2016-11-11 21:29:50 +0300
commit8ed0994a0abaa8ceffe075d76d6caa0d77beac56 (patch)
treeaeeed124036058367c8713c9c0d60177c9e4b14b
parenta665438965bffdf99136b3bc02e0d968fde54835 (diff)
parent489620ce40d8f1d48cd11ed317034ef18067f228 (diff)
Merge pull request #834 from killeent/lognormal-fix
Fix implementation of logNormal
-rw-r--r--lib/TH/THRandom.c4
-rw-r--r--test/test.lua11
2 files changed, 12 insertions, 3 deletions
diff --git a/lib/TH/THRandom.c b/lib/TH/THRandom.c
index 55ee943..fbaf282 100644
--- a/lib/TH/THRandom.c
+++ b/lib/TH/THRandom.c
@@ -255,10 +255,8 @@ double THRandom_cauchy(THGenerator *_generator, double median, double sigma)
M'enfin. */
double THRandom_logNormal(THGenerator *_generator, double mean, double stdv)
{
- double zm = mean*mean;
- double zs = stdv*stdv;
THArgCheck(stdv > 0, 2, "standard deviation must be strictly positive");
- return(exp(THRandom_normal(_generator, log(zm/sqrt(zs + zm)), sqrt(log(zs/zm+1)) )));
+ return(exp(THRandom_normal(_generator, mean, stdv)));
}
int THRandom_geometric(THGenerator *_generator, double p)
diff --git a/test/test.lua b/test/test.lua
index 4290036..81da692 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -3435,6 +3435,17 @@ function torchtest.bernoulli()
mytester:assert(isBinary(t), 'Sample from torch.bernoulli is not binary')
end
+function torchtest.logNormal()
+ local t = torch.FloatTensor(10, 10)
+ local mean, std = torch.uniform(), 0.1 * torch.uniform()
+ local tolerance = 0.01
+
+ t:logNormal(mean, std)
+ local logt = t:log()
+ mytester:assertalmosteq(logt:mean(), mean, tolerance, 'mean is wrong')
+ mytester:assertalmosteq(logt:std(), std, tolerance, 'tolerance is wrong')
+end
+
function torch.test(tests)
torch.setheaptracking(true)
math.randomseed(os.time())