Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/argcheck.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonas Gehring <jgehring@fb.com>2016-03-10 22:14:26 +0300
committerJonas Gehring <jgehring@fb.com>2016-03-10 22:14:26 +0300
commit89c38818d575cf5571c2f8d0cb2f8ed4867c821e (patch)
tree8c9aff38e7afd5fa73a9456db8a0c16acfe8ade0 /env.lua
parentaa0f434b6673b8713c54154bc9f8d8f095be20df (diff)
Fixed torch type checking
Enforce exact type matching for non-pattern strings. This fixes false positives like a 'torch.LongTensor' object matching 'torch.Long'. Also, a getmetatable() call was missing for the non-torch version of env.type().
Diffstat (limited to 'env.lua')
-rw-r--r--env.lua11
1 files changed, 8 insertions, 3 deletions
diff --git a/env.lua b/env.lua
index f62ebb6..082cd63 100644
--- a/env.lua
+++ b/env.lua
@@ -13,6 +13,7 @@ function env.istype(obj, typename)
end
function env.type(obj)
+ local mt = getmetatable(obj)
if type(mt) == 'table' then
local objtype = rawget(mt, '__typename')
if objtype then
@@ -28,13 +29,17 @@ if pcall(require, 'torch') then
local thname = torch.typename(obj)
if thname then
-- __typename (see below) might be absent
- if thname:match(typename) then
+ local match = thname:match(typename)
+ if match and (match ~= typename or match == thname) then
return true
end
local mt = torch.getmetatable(thname)
while mt do
- if mt.__typename and mt.__typename:match(typename) then
- return true
+ if mt.__typename then
+ match = mt.__typename:match(typename)
+ if match and (match ~= typename or match == mt.__typename) then
+ return true
+ end
end
mt = getmetatable(mt)
end