Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Belanger <david.b.belanger@gmail.com>2015-11-25 00:37:16 +0300
committerDavid Belanger <david.b.belanger@gmail.com>2015-11-25 00:37:16 +0300
commit7e4d57e26eb6b2c294d58522d08d22f3726c4551 (patch)
tree1726989870e450c7bcd0591755cbb400add6e91a /PushTable.lua
parent888c940fe953e12b7277f4b9b09761fc73818062 (diff)
fix infinite recursion in PullTable and PushTable when converting to different types
Diffstat (limited to 'PushTable.lua')
-rw-r--r--PushTable.lua14
1 files changed, 14 insertions, 0 deletions
diff --git a/PushTable.lua b/PushTable.lua
index b7cfb64..3131c69 100644
--- a/PushTable.lua
+++ b/PushTable.lua
@@ -60,4 +60,18 @@ function PushTable:updateGradInput(inputTable, gradOutputTable)
end
+function PushTable:type(type, tensorCache)
+ assert(type, 'PullTable: must provide a type to convert to')
+
+ tensorCache = tensorCache or {}
+
+ -- find all tensors and convert them
+ for key,param in pairs(self) do
+ if(key ~= "_pulls") then
+ self[key] = nn.utils.recursiveType(param, type, tensorCache)
+ end
+ end
+ return self
+end
+