diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-05-26 21:08:08 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-05-26 21:08:08 +0300 |
commit | 3ee0de9d572f4b078a9d51a75ced19ac9b3dbe2c (patch) | |
tree | 13d9531767250c67085d0e8e3cc36f41497450c7 /convert.lua | |
parent | 9a5c2f02c7b0415f021d568caccd65002f6658ee (diff) |
use replace from nn
Diffstat (limited to 'convert.lua')
-rw-r--r-- | convert.lua | 15 |
1 files changed, 1 insertions, 14 deletions
diff --git a/convert.lua b/convert.lua index 638928b..9371e27 100644 --- a/convert.lua +++ b/convert.lua @@ -18,23 +18,10 @@ local layer_list = { 'VolumetricAveragePooling', } --- similar to nn.Module.apply --- goes over a net and recursively replaces modules --- using callback function -local function replace(self, callback) - local out = callback(self) - if self.modules then - for i, module in ipairs(self.modules) do - self.modules[i] = replace(module, callback) - end - end - return out -end - -- goes over a given net and converts all layers to dst backend -- for example: net = cudnn.convert(net, cudnn) function cudnn.convert(net, dst) - return replace(net, function(x) + return net:replace(function(x) local y = 0 local src = dst == nn and cudnn or nn local src_prefix = src == nn and 'nn.' or 'cudnn.' |