Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Belanger <david.b.belanger@gmail.com>2015-11-25 00:37:16 +0300
committerDavid Belanger <david.b.belanger@gmail.com>2015-11-25 00:37:16 +0300
commit7e4d57e26eb6b2c294d58522d08d22f3726c4551 (patch)
tree1726989870e450c7bcd0591755cbb400add6e91a /PullTable.lua
parent888c940fe953e12b7277f4b9b09761fc73818062 (diff)
fix infinite recursion in PullTable and PushTable when converting to different types
Diffstat (limited to 'PullTable.lua')
-rw-r--r--PullTable.lua17
1 files changed, 17 insertions, 0 deletions
diff --git a/PullTable.lua b/PullTable.lua
index 1fed379..1413a07 100644
--- a/PullTable.lua
+++ b/PullTable.lua
@@ -58,3 +58,20 @@ function PullTable:updateGradInput(inputTable, gradOutputTable)
end
return self.gradInput
end
+
+
+function PullTable:type(type, tensorCache)
+ assert(type, 'PullTable: must provide a type to convert to')
+
+ tensorCache = tensorCache or {}
+
+ -- find all tensors and convert them
+ for key,param in pairs(self) do
+ if(key ~= "_push") then
+ self[key] = nn.utils.recursiveType(param, type, tensorCache)
+ end
+ end
+
+ return self
+end
+