Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Natalia Gimeshein <ngimelshein@nvidia.com> 2016-10-08 09:27:04 +0300
committer: Natalia Gimeshein <ngimelshein@nvidia.com> 2016-10-08 09:27:04 +0300
commit: 1f358d7bed0bb7c165dd1e5dd7829da35277be5e (patch)
tree: 9df0b6641625f680c1d6d3adb5c1486db7d44225
parent: 068a0d2a85a3090d324656a2d7cf238952e8a91f (diff)
clone output tensor in FindEx
-rw-r--r-- find.lua 8
1 files changed, 4 insertions, 4 deletions
diff --git a/find.lua b/find.lua
index 81af459..af6481a 100644
--- a/find.lua
+++ b/find.lua
@@ -321,10 +321,9 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params)
self.algoFamily = setAlgoFamily()
local API = algoFamilies[self.algoFamily][findAPI_idx]
if self.algoFamily == FindExFamily then
- -- use clone for weights when looking for backward filter algo
- if findAPI_idx == BwdFilter then
- params[7] = params[7]:clone()
- end
+ -- clone output tensor
+ local paramstmp = params[7]
+ params[7] = paramstmp:clone()
-- temporarily set WS size to the max
self:calculateMaxWorkspaceSize()
cudnn.setSharedWorkspaceSize(self.maxWorkspaceSize)
@@ -345,6 +344,7 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params)
cudnn.getHandle(),
params[1], params[2]:data(), params[3], params[4]:data(), layer.convDesc[0], params[6], params[7]:data(),
nAlgos, numPerfResults, perfResults, tempWorkspace, tempWorkspaceSize)
+ params[7]=paramstmp
else
if self.algoFamily == FindFamily then
ret = call(layer, API,