Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/argcheck.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorJonas Gehring <jgehring@fb.com>2016-03-10 22:14:26 +0300
committerJonas Gehring <jgehring@fb.com>2016-03-10 22:14:26 +0300
commit89c38818d575cf5571c2f8d0cb2f8ed4867c821e (patch)
tree8c9aff38e7afd5fa73a9456db8a0c16acfe8ade0 /test
parentaa0f434b6673b8713c54154bc9f8d8f095be20df (diff)
Fixed torch type checking
Enforce exact type matching for non-pattern strings. This fixes false positives like a 'torch.LongTensor' object matching 'torch.Long'. Also, a getmetatable() call was missing for the non-torch version of env.type().
Diffstat (limited to 'test')
-rw-r--r--test/test.lua14
1 files changed, 14 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index 19b969d..b3778bc 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1,4 +1,5 @@
local argcheck = require 'argcheck'
+local env = require 'argcheck.env'
function addfive(x)
return string.format('%f + 5 = %f', x, x+5)
@@ -321,4 +322,17 @@ assert(addstuff{x=3, y=4} == '3.000000 + 4.000000 = 7.000000 [msg=NULL]')
assert(addstuff(3, 4, 'paf') == '3.000000 + 4.000000 = 7.000000 [msg=paf]')
assert(addstuff{x=3, y=4, msg='paf'} == '3.000000 + 4.000000 = 7.000000 [msg=paf]')
+assert(env.type('string') == 'string')
+assert(env.type(foobar) == 'foobar')
+
+if pcall(require, 'torch') then
+ local t = torch.LongTensor()
+ assert(env.type(t) == 'torch.LongTensor')
+ assert(env.istype(t, 'torch.LongTensor') == true)
+ assert(env.istype(t, 'torch.*Tensor') == true)
+ assert(env.istype(t, '.*Long') == true)
+ assert(env.istype(t, 'torch.IntTensor') == false)
+ assert(env.istype(t, 'torch.Long') == false)
+end
+
print('PASSED')