Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2016-05-26 21:08:08 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2016-05-26 21:08:08 +0300
commit3ee0de9d572f4b078a9d51a75ced19ac9b3dbe2c (patch)
tree13d9531767250c67085d0e8e3cc36f41497450c7 /convert.lua
parent9a5c2f02c7b0415f021d568caccd65002f6658ee (diff)
use replace from nn
Diffstat (limited to 'convert.lua')
-rw-r--r--convert.lua15
1 files changed, 1 insertions, 14 deletions
diff --git a/convert.lua b/convert.lua
index 638928b..9371e27 100644
--- a/convert.lua
+++ b/convert.lua
@@ -18,23 +18,10 @@ local layer_list = {
'VolumetricAveragePooling',
}
--- similar to nn.Module.apply
--- goes over a net and recursively replaces modules
--- using callback function
-local function replace(self, callback)
- local out = callback(self)
- if self.modules then
- for i, module in ipairs(self.modules) do
- self.modules[i] = replace(module, callback)
- end
- end
- return out
-end
-
-- goes over a given net and converts all layers to dst backend
-- for example: net = cudnn.convert(net, cudnn)
function cudnn.convert(net, dst)
- return replace(net, function(x)
+ return net:replace(function(x)
local y = 0
local src = dst == nn and cudnn or nn
local src_prefix = src == nn and 'nn.' or 'cudnn.'