diff options
author | Jonas Gehring <jgehring@fb.com> | 2016-03-10 22:14:26 +0300 |
---|---|---|
committer | Jonas Gehring <jgehring@fb.com> | 2016-03-10 22:14:26 +0300 |
commit | 89c38818d575cf5571c2f8d0cb2f8ed4867c821e (patch) | |
tree | 8c9aff38e7afd5fa73a9456db8a0c16acfe8ade0 | |
parent | aa0f434b6673b8713c54154bc9f8d8f095be20df (diff) |
Fixed torch type checking
Enforce exact type matching for non-pattern strings. This fixes
false positives like a 'torch.LongTensor' object matching 'torch.Long'.
Also, a getmetatable() call was missing for the non-torch version
of env.type().
-rw-r--r-- | env.lua | 11 | ||||
-rw-r--r-- | test/test.lua | 14 |
2 files changed, 22 insertions, 3 deletions
@@ -13,6 +13,7 @@ function env.istype(obj, typename) end function env.type(obj) + local mt = getmetatable(obj) if type(mt) == 'table' then local objtype = rawget(mt, '__typename') if objtype then @@ -28,13 +29,17 @@ if pcall(require, 'torch') then local thname = torch.typename(obj) if thname then -- __typename (see below) might be absent - if thname:match(typename) then + local match = thname:match(typename) + if match and (match ~= typename or match == thname) then return true end local mt = torch.getmetatable(thname) while mt do - if mt.__typename and mt.__typename:match(typename) then - return true + if mt.__typename then + match = mt.__typename:match(typename) + if match and (match ~= typename or match == mt.__typename) then + return true + end end mt = getmetatable(mt) end diff --git a/test/test.lua b/test/test.lua index 19b969d..b3778bc 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1,4 +1,5 @@ local argcheck = require 'argcheck' +local env = require 'argcheck.env' function addfive(x) return string.format('%f + 5 = %f', x, x+5) @@ -321,4 +322,17 @@ assert(addstuff{x=3, y=4} == '3.000000 + 4.000000 = 7.000000 [msg=NULL]') assert(addstuff(3, 4, 'paf') == '3.000000 + 4.000000 = 7.000000 [msg=paf]') assert(addstuff{x=3, y=4, msg='paf'} == '3.000000 + 4.000000 = 7.000000 [msg=paf]') +assert(env.type('string') == 'string') +assert(env.type(foobar) == 'foobar') + +if pcall(require, 'torch') then + local t = torch.LongTensor() + assert(env.type(t) == 'torch.LongTensor') + assert(env.istype(t, 'torch.LongTensor') == true) + assert(env.istype(t, 'torch.*Tensor') == true) + assert(env.istype(t, '.*Long') == true) + assert(env.istype(t, 'torch.IntTensor') == false) + assert(env.istype(t, 'torch.Long') == false) +end + print('PASSED') |