diff options
author | David Belanger <david.b.belanger@gmail.com> | 2015-11-25 00:37:16 +0300 |
---|---|---|
committer | David Belanger <david.b.belanger@gmail.com> | 2015-11-25 00:37:16 +0300 |
commit | 7e4d57e26eb6b2c294d58522d08d22f3726c4551 (patch) | |
tree | 1726989870e450c7bcd0591755cbb400add6e91a /PullTable.lua | |
parent | 888c940fe953e12b7277f4b9b09761fc73818062 (diff) |
fix infinite recursion in PullTable and PushTable when converting to different types
Diffstat (limited to 'PullTable.lua')
-rw-r--r-- | PullTable.lua | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/PullTable.lua b/PullTable.lua index 1fed379..1413a07 100644 --- a/PullTable.lua +++ b/PullTable.lua @@ -58,3 +58,20 @@ function PullTable:updateGradInput(inputTable, gradOutputTable) end return self.gradInput end + + +function PullTable:type(type, tensorCache) + assert(type, 'PullTable: must provide a type to convert to') + + tensorCache = tensorCache or {} + + -- find all tensors and convert them + for key,param in pairs(self) do + if(key ~= "_push") then + self[key] = nn.utils.recursiveType(param, type, tensorCache) + end + end + + return self +end + |