diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-11-11 21:29:50 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-11 21:29:50 +0300 |
commit | 8ed0994a0abaa8ceffe075d76d6caa0d77beac56 (patch) | |
tree | aeeed124036058367c8713c9c0d60177c9e4b14b | |
parent | a665438965bffdf99136b3bc02e0d968fde54835 (diff) | |
parent | 489620ce40d8f1d48cd11ed317034ef18067f228 (diff) |
Merge pull request #834 from killeent/lognormal-fix
Fix implementation of logNormal
-rw-r--r-- | lib/TH/THRandom.c | 4 | ||||
-rw-r--r-- | test/test.lua | 11 |
2 files changed, 12 insertions, 3 deletions
diff --git a/lib/TH/THRandom.c b/lib/TH/THRandom.c index 55ee943..fbaf282 100644 --- a/lib/TH/THRandom.c +++ b/lib/TH/THRandom.c @@ -255,10 +255,8 @@ double THRandom_cauchy(THGenerator *_generator, double median, double sigma) M'enfin. */ double THRandom_logNormal(THGenerator *_generator, double mean, double stdv) { - double zm = mean*mean; - double zs = stdv*stdv; THArgCheck(stdv > 0, 2, "standard deviation must be strictly positive"); - return(exp(THRandom_normal(_generator, log(zm/sqrt(zs + zm)), sqrt(log(zs/zm+1)) ))); + return(exp(THRandom_normal(_generator, mean, stdv))); } int THRandom_geometric(THGenerator *_generator, double p) diff --git a/test/test.lua b/test/test.lua index 4290036..81da692 100644 --- a/test/test.lua +++ b/test/test.lua @@ -3435,6 +3435,17 @@ function torchtest.bernoulli() mytester:assert(isBinary(t), 'Sample from torch.bernoulli is not binary') end +function torchtest.logNormal() + local t = torch.FloatTensor(10, 10) + local mean, std = torch.uniform(), 0.1 * torch.uniform() + local tolerance = 0.01 + + t:logNormal(mean, std) + local logt = t:log() + mytester:assertalmosteq(logt:mean(), mean, tolerance, 'mean is wrong') + mytester:assertalmosteq(logt:std(), std, tolerance, 'tolerance is wrong') +end + function torch.test(tests) torch.setheaptracking(true) math.randomseed(os.time()) |