diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-08-16 19:27:34 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-16 19:27:34 +0300 |
commit | e079d80ea624cbd996e66aa3d6716af31f3a9aa5 (patch) | |
tree | 5244a52cde751fec8a6eb2abace9386a94727153 | |
parent | 1604e11c5b1a459464360750307b144ddd02f1ce (diff) | |
parent | 76b344546f714833e74045bcb793515ba4435e6b (diff) |
Merge pull request #245 from szagoruyko/bn-backward-assert
Prevent BatchNorm from backward in evaluate mode
-rw-r--r-- | BatchNormalization.lua | 1 |
1 files changed, 1 insertions, 0 deletions
diff --git a/BatchNormalization.lua b/BatchNormalization.lua
index ac77e4f..5b4e9d4 100644
--- a/BatchNormalization.lua
+++ b/BatchNormalization.lua
@@ -81,6 +81,7 @@ function BatchNormalization:updateOutput(input)
 end

 local function backward(self,input,gradOutput, scale)
+   assert(self.train, 'cudnn.BatchNormalization doesnt support backward in evaluate, use nn')
    self.scaleT = self.scaleT or self.weight.new(1)
    -- this line forces this member to always be on CPU (needed for cudnn)
    self.scaleT = torch.type(self.weight) == 'torch.CudaDoubleTensor'