diff options
author | David Belanger <david.b.belanger@gmail.com> | 2015-11-25 00:37:16 +0300 |
---|---|---|
committer | David Belanger <david.b.belanger@gmail.com> | 2015-11-25 00:37:16 +0300 |
commit | 7e4d57e26eb6b2c294d58522d08d22f3726c4551 (patch) | |
tree | 1726989870e450c7bcd0591755cbb400add6e91a /PushTable.lua | |
parent | 888c940fe953e12b7277f4b9b09761fc73818062 (diff) |
fix infinite recursion in PullTable and PushTable when converting to different types
Diffstat (limited to 'PushTable.lua')
-rw-r--r-- | PushTable.lua | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/PushTable.lua b/PushTable.lua index b7cfb64..3131c69 100644 --- a/PushTable.lua +++ b/PushTable.lua @@ -60,4 +60,18 @@ function PushTable:updateGradInput(inputTable, gradOutputTable) end +function PushTable:type(type, tensorCache) + assert(type, 'PullTable: must provide a type to convert to') + + tensorCache = tensorCache or {} + + -- find all tensors and convert them + for key,param in pairs(self) do + if(key ~= "_pulls") then + self[key] = nn.utils.recursiveType(param, type, tensorCache) + end + end + return self +end + |