diff options
author | Natalia Gimeshein <ngimelshein@nvidia.com> | 2016-10-08 09:27:04 +0300 |
---|---|---|
committer | Natalia Gimeshein <ngimelshein@nvidia.com> | 2016-10-08 09:27:04 +0300 |
commit | 1f358d7bed0bb7c165dd1e5dd7829da35277be5e (patch) | |
tree | 9df0b6641625f680c1d6d3adb5c1486db7d44225 | |
parent | 068a0d2a85a3090d324656a2d7cf238952e8a91f (diff) |
clone output tensor in FindEx
-rw-r--r-- | find.lua | 8 |
1 files changed, 4 insertions, 4 deletions
@@ -321,10 +321,9 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params) self.algoFamily = setAlgoFamily() local API = algoFamilies[self.algoFamily][findAPI_idx] if self.algoFamily == FindExFamily then - -- use clone for weights when looking for backward filter algo - if findAPI_idx == BwdFilter then - params[7] = params[7]:clone() - end + -- clone output tensor + local paramstmp = params[7] + params[7] = paramstmp:clone() -- temporarily set WS size to the max self:calculateMaxWorkspaceSize() cudnn.setSharedWorkspaceSize(self.maxWorkspaceSize) @@ -345,6 +344,7 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params) cudnn.getHandle(), params[1], params[2]:data(), params[3], params[4]:data(), layer.convDesc[0], params[6], params[7]:data(), nAlgos, numPerfResults, perfResults, tempWorkspace, tempWorkspaceSize) + params[7]=paramstmp else if self.algoFamily == FindFamily then ret = call(layer, API, |