diff options
author | Andreas Köpf <andreas.koepf@xamla.com> | 2016-02-27 04:48:13 +0300 |
---|---|---|
committer | Andreas Köpf <andreas.koepf@xamla.com> | 2016-03-05 01:17:09 +0300 |
commit | fdb14f22e7db5511294ac1591118303a37f48bd9 (patch) | |
tree | e9a440502b5d6f878d1ba4b428993a5c178af1e6 /test | |
parent | 1cf98511b39a486f74aa4ecce02f8de422101a2d (diff) |
Add math functions trunc, frac, rsqrt, lerp
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index d1bb96c..93b71c6 100644 --- a/test/test.lua +++ b/test/test.lua @@ -162,6 +162,26 @@ function torchtest.sqrt() mytester:assertlt(maxerrnc, precision, 'error in torch.functionname - non-contiguous') end +function torchtest.rsqrt() + local function TH_rsqrt(x) + return 1 / math.sqrt(x) + end + + local f + local t = genericSingleOpTest:gsub('functionname', 'rsqrt'):gsub('math.rsqrt', 'TH_rsqrt') + local env = { TH_rsqrt=TH_rsqrt, torch=torch, math=math } + if not setfenv then -- Lua 5.2 + f = load(t, 'test', 't', env) + else + f = loadstring(t) + setfenv(f, env) + end + + local maxerrc, maxerrnc = f() + mytester:assertlt(maxerrc, precision, 'error in torch.functionname - contiguous') + mytester:assertlt(maxerrnc, precision, 'error in torch.functionname - non-contiguous') +end + function torchtest.sigmoid() -- cant use genericSingleOpTest, since `math.sigmoid` doesnt exist, have to use -- `torch.sigmoid` instead @@ -208,6 +228,46 @@ function torchtest.ceil() mytester:assertlt(maxerrnc, precision, 'error in torch.functionname - non-contiguous') end +function torchtest.frac() + local function TH_frac(x) + return math.fmod(x, 1) + end + + local f + local t = genericSingleOpTest:gsub('functionname', 'frac'):gsub('math.frac', 'TH_frac') + local env = { TH_frac=TH_frac, torch=torch, math=math } + if not setfenv then -- Lua 5.2 + f = load(t, 'test', 't', env) + else + f = loadstring(t) + setfenv(f, env) + end + + local maxerrc, maxerrnc = f() + mytester:assertlt(maxerrc, precision, 'error in torch.functionname - contiguous') + mytester:assertlt(maxerrnc, precision, 'error in torch.functionname - non-contiguous') +end + +function torchtest.trunc() + local function TH_trunc(x) + return x - math.fmod(x, 1) + end + + local f + local t = genericSingleOpTest:gsub('functionname', 'trunc'):gsub('math.trunc', 'TH_trunc') + local env = { TH_trunc=TH_trunc, torch=torch, math=math } + if not setfenv then -- Lua 5.2 + f = load(t, 'test', 't', env) + else + f = loadstring(t) + setfenv(f, env) + end + + local maxerrc, maxerrnc = f() + mytester:assertlt(maxerrc, precision, 'error in torch.functionname - contiguous') + mytester:assertlt(maxerrnc, precision, 'error in torch.functionname - non-contiguous') +end + function torchtest.round() -- [res] torch.round([res,] x) -- contiguous @@ -425,6 +485,27 @@ function torchtest.cmin() 'error in torch.cmin(tensor, scalar).') end +function torchtest.lerp() + local function TH_lerp(a, b, weight) + return a + weight * (b-a); + end + + local a = torch.rand(msize, msize) + local b = torch.rand(msize, msize) + local w = math.random() + local result = torch.lerp(a, b, w) + local expected = a:new() + expected:map2(a, b, function(_, a, b) return TH_lerp(a, b, w) end) + mytester:assertTensorEq(expected, result, precision, 'error in torch.lerp(tensor, tensor, weight)') + + local a = (math.random()*2-1) * 100000 + local b = (math.random()*2-1) * 100000 + local w = math.random() + local result = torch.lerp(a, b, w) + local expected = TH_lerp(a, b, w) + mytester:assertalmosteq(expected, result, precision, 'error in torch.lerp(scalar, scalar, weight)') +end + for i, v in ipairs{{10}, {5, 5}} do torchtest['allAndAny' .. i] = function () |