diff options
author | Jonas Gehring <jgehring@fb.com> | 2016-03-10 22:14:26 +0300 |
---|---|---|
committer | Jonas Gehring <jgehring@fb.com> | 2016-03-10 22:14:26 +0300 |
commit | 89c38818d575cf5571c2f8d0cb2f8ed4867c821e (patch) | |
tree | 8c9aff38e7afd5fa73a9456db8a0c16acfe8ade0 /env.lua | |
parent | aa0f434b6673b8713c54154bc9f8d8f095be20df (diff) |
Fixed torch type checking
Enforce exact type matching for non-pattern strings. This fixes
false positives like a 'torch.LongTensor' object matching 'torch.Long'.
Also, a getmetatable() call was missing for the non-torch version
of env.type().
Diffstat (limited to 'env.lua')
-rw-r--r-- | env.lua | 11 |
1 files changed, 8 insertions, 3 deletions
@@ -13,6 +13,7 @@ function env.istype(obj, typename) end function env.type(obj) + local mt = getmetatable(obj) if type(mt) == 'table' then local objtype = rawget(mt, '__typename') if objtype then @@ -28,13 +29,17 @@ if pcall(require, 'torch') then local thname = torch.typename(obj) if thname then -- __typename (see below) might be absent - if thname:match(typename) then + local match = thname:match(typename) + if match and (match ~= typename or match == thname) then return true end local mt = torch.getmetatable(thname) while mt do - if mt.__typename and mt.__typename:match(typename) then - return true + if mt.__typename then + match = mt.__typename:match(typename) + if match and (match ~= typename or match == mt.__typename) then + return true + end end mt = getmetatable(mt) end |