Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorAndreas Köpf <andreas.koepf@xamla.com>2016-02-27 04:48:13 +0300
committerAndreas Köpf <andreas.koepf@xamla.com>2016-03-05 01:17:09 +0300
commitfdb14f22e7db5511294ac1591118303a37f48bd9 (patch)
treee9a440502b5d6f878d1ba4b428993a5c178af1e6 /test
parent1cf98511b39a486f74aa4ecce02f8de422101a2d (diff)
Add math functions trunc, frac, rsqrt, lerp
Diffstat (limited to 'test')
-rw-r--r--test/test.lua81
1 files changed, 81 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index d1bb96c..93b71c6 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -162,6 +162,26 @@ function torchtest.sqrt()
mytester:assertlt(maxerrnc, precision, 'error in torch.functionname - non-contiguous')
end
+function torchtest.rsqrt()
+ local function TH_rsqrt(x)
+ return 1 / math.sqrt(x)
+ end
+
+ local f
+ local t = genericSingleOpTest:gsub('functionname', 'rsqrt'):gsub('math.rsqrt', 'TH_rsqrt')
+ local env = { TH_rsqrt=TH_rsqrt, torch=torch, math=math }
+ if not setfenv then -- Lua 5.2
+ f = load(t, 'test', 't', env)
+ else
+ f = loadstring(t)
+ setfenv(f, env)
+ end
+
+ local maxerrc, maxerrnc = f()
+ mytester:assertlt(maxerrc, precision, 'error in torch.functionname - contiguous')
+ mytester:assertlt(maxerrnc, precision, 'error in torch.functionname - non-contiguous')
+end
+
function torchtest.sigmoid()
-- cant use genericSingleOpTest, since `math.sigmoid` doesnt exist, have to use
-- `torch.sigmoid` instead
@@ -208,6 +228,46 @@ function torchtest.ceil()
mytester:assertlt(maxerrnc, precision, 'error in torch.functionname - non-contiguous')
end
+function torchtest.frac()
+ local function TH_frac(x)
+ return math.fmod(x, 1)
+ end
+
+ local f
+ local t = genericSingleOpTest:gsub('functionname', 'frac'):gsub('math.frac', 'TH_frac')
+ local env = { TH_frac=TH_frac, torch=torch, math=math }
+ if not setfenv then -- Lua 5.2
+ f = load(t, 'test', 't', env)
+ else
+ f = loadstring(t)
+ setfenv(f, env)
+ end
+
+ local maxerrc, maxerrnc = f()
+ mytester:assertlt(maxerrc, precision, 'error in torch.functionname - contiguous')
+ mytester:assertlt(maxerrnc, precision, 'error in torch.functionname - non-contiguous')
+end
+
+function torchtest.trunc()
+ local function TH_trunc(x)
+ return x - math.fmod(x, 1)
+ end
+
+ local f
+ local t = genericSingleOpTest:gsub('functionname', 'trunc'):gsub('math.trunc', 'TH_trunc')
+ local env = { TH_trunc=TH_trunc, torch=torch, math=math }
+ if not setfenv then -- Lua 5.2
+ f = load(t, 'test', 't', env)
+ else
+ f = loadstring(t)
+ setfenv(f, env)
+ end
+
+ local maxerrc, maxerrnc = f()
+ mytester:assertlt(maxerrc, precision, 'error in torch.functionname - contiguous')
+ mytester:assertlt(maxerrnc, precision, 'error in torch.functionname - non-contiguous')
+end
+
function torchtest.round()
-- [res] torch.round([res,] x)
-- contiguous
@@ -425,6 +485,27 @@ function torchtest.cmin()
'error in torch.cmin(tensor, scalar).')
end
+function torchtest.lerp()
+ local function TH_lerp(a, b, weight)
+ return a + weight * (b-a);
+ end
+
+ local a = torch.rand(msize, msize)
+ local b = torch.rand(msize, msize)
+ local w = math.random()
+ local result = torch.lerp(a, b, w)
+ local expected = a:new()
+ expected:map2(a, b, function(_, a, b) return TH_lerp(a, b, w) end)
+ mytester:assertTensorEq(expected, result, precision, 'error in torch.lerp(tensor, tensor, weight)')
+
+ local a = (math.random()*2-1) * 100000
+ local b = (math.random()*2-1) * 100000
+ local w = math.random()
+ local result = torch.lerp(a, b, w)
+ local expected = TH_lerp(a, b, w)
+ mytester:assertalmosteq(expected, result, precision, 'error in torch.lerp(scalar, scalar, weight)')
+end
+
for i, v in ipairs{{10}, {5, 5}} do
torchtest['allAndAny' .. i] =
function ()