diff options
author | Natalia Gimeshein <ngimelshein@nvidia.com> | 2016-10-05 23:57:42 +0300 |
---|---|---|
committer | Natalia Gimeshein <ngimelshein@nvidia.com> | 2016-10-05 23:57:42 +0300 |
commit | 942d7965c73d5d210807f48a95eeb4fdac69af5e (patch) | |
tree | 9cad35b6365ece2b8df0ae8bd63fe4e271c1cb9d | |
parent | ba9513c9f17580a6c2f75f6d499086cdfbc0b3d3 (diff) |
reset algo family on size change
-rw-r--r-- | find.lua | 16 |
1 file changed, 11 insertions, 5 deletions
```diff
@@ -153,12 +153,17 @@ end
 
 local finders = nil
 -- this resets algorithm cache for device
+
+local function setAlgoFamily()
+   return cudnn.benchmark
+      and (cudnn.useFindEx and FindExFamily or FindFamily)
+      or GetFamily
+end
+
 function find:resetAlgorithmCache()
    self.calculatedWorkspaceSize = {}
    self:calculateMaxWorkspaceSize()
-   self.algoFamily = cudnn.benchmark
-      and (cudnn.useFindEx and FindExFamily or FindFamily)
-      or GetFamily
+   self.algoFamily = setAlgoFamily()
    self.autotunerCache = {{}, {}, {}}
 end
 
@@ -312,7 +317,9 @@ function find:setupAlgo(layer, findAPI_idx, algSearchMode, params)
    else
       cacheHit = ''
       cachedAlgo = {}
-
+--algo family might have changed, reset it
+      self.algoFamily = setAlgoFamily()
+      local API = algoFamilies[self.algoFamily][findAPI_idx]
       if self.algoFamily == FindExFamily then
          -- use clone for weights when looking for backward filter algo
          if findAPI_idx == BwdFilter then
@@ -469,7 +476,6 @@ function find:prepare(layer, input_slice, output_slice)
       ..' -convStrideA' .. vals(layer.stride)
       .. ' ' .. cudnn.configmap(torch.type(layer.weight))
 
-   layer:resetMode()
    layer.iteration = nil
    layer.input_slice = input_slice
    layer.output_slice = output_slice
```