Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNatalia Gimeshein <ngimelshein@nvidia.com>2016-10-05 23:57:42 +0300
committerNatalia Gimeshein <ngimelshein@nvidia.com>2016-10-05 23:57:42 +0300
commit942d7965c73d5d210807f48a95eeb4fdac69af5e (patch)
tree9cad35b6365ece2b8df0ae8bd63fe4e271c1cb9d
parentba9513c9f17580a6c2f75f6d499086cdfbc0b3d3 (diff)
reset algo family on size change
-rw-r--r--find.lua16
1 files changed, 11 insertions, 5 deletions
diff --git a/find.lua b/find.lua
index 3873efa..81af459 100644
--- a/find.lua
+++ b/find.lua
@@ -153,12 +153,17 @@ end
local finders = nil
-- this resets algorithm cache for device
+
+local function setAlgoFamily()
+ return cudnn.benchmark
+ and (cudnn.useFindEx and FindExFamily or FindFamily)
+ or GetFamily
+end
+
function find:resetAlgorithmCache()
self.calculatedWorkspaceSize = {}
self:calculateMaxWorkspaceSize()
- self.algoFamily = cudnn.benchmark
- and (cudnn.useFindEx and FindExFamily or FindFamily)
- or GetFamily
+ self.algoFamily = setAlgoFamily()
self.autotunerCache = {{}, {}, {}}
end
@@ -312,7 +317,9 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params)
else
cacheHit = ''
cachedAlgo = {}
-
+--algo family might have changed, reset it
+ self.algoFamily = setAlgoFamily()
+ local API = algoFamilies[self.algoFamily][findAPI_idx]
if self.algoFamily == FindExFamily then
-- use clone for weights when looking for backward filter algo
if findAPI_idx == BwdFilter then
@@ -469,7 +476,6 @@ function find:prepare(layer, input_slice, output_slice)
..' -convStrideA' .. vals(layer.stride)
.. ' ' .. cudnn.configmap(torch.type(layer.weight))
- layer:resetMode()
layer.iteration = nil
layer.input_slice = input_slice
layer.output_slice = output_slice