diff options
Diffstat (limited to 'hessian.lua')
-rw-r--r-- | hessian.lua | 62 |
1 files changed, 61 insertions, 1 deletions
diff --git a/hessian.lua b/hessian.lua
index c55e066..3d336fe 100644
--- a/hessian.lua
+++ b/hessian.lua
@@ -164,6 +164,24 @@ function nn.hessian.enable()
    end
 
    ----------------------------------------------------------------------
+   -- WeightedMSECriterion
+   ----------------------------------------------------------------------
+   function nn.WeightedMSECriterion.updateDiagHessianInput(self,input,target)
+      return nn.MSECriterion.updateDiagHessianInput(self,input,target) -- delegate to the MSECriterion implementation
+   end
+
+   ----------------------------------------------------------------------
+   -- L1Cost
+   ----------------------------------------------------------------------
+   function nn.L1Cost.updateDiagHessianInput(self,input)
+      self.diagHessianInput = self.diagHessianInput or input.new()
+      self.diagHessianInput:resizeAs(input)
+      self.diagHessianInput:fill(1) -- 1 everywhere ...
+      self.diagHessianInput[torch.eq(input,0)] = 0 -- ... except 0 where input == 0
+      return self.diagHessianInput
+   end
+
+   ----------------------------------------------------------------------
    -- Linear
    ----------------------------------------------------------------------
    function nn.Linear.updateDiagHessianInput(self, input, diagHessianOutput)
@@ -188,7 +206,7 @@ function nn.hessian.enable()
    end
 
    function nn.SpatialConvolution.accDiagHessianParameters(self, input, diagHessianOutput)
-      accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight'}, {'diagHessianWeight'})
+      accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight','gradBias'}, {'diagHessianWeight','diagHessianBias'})
    end
 
    function nn.SpatialConvolution.initDiagHessianParameters(self)
@@ -196,6 +214,22 @@ function nn.hessian.enable()
    end
 
    ----------------------------------------------------------------------
+   -- SpatialFullConvolution
+   ----------------------------------------------------------------------
+   function nn.SpatialFullConvolution.updateDiagHessianInput(self, input, diagHessianOutput)
+      updateDiagHessianInput(self, input, diagHessianOutput, {'weight'}, {'weightSq'})
+      return self.diagHessianInput
+   end
+
+   function nn.SpatialFullConvolution.accDiagHessianParameters(self, input, diagHessianOutput)
+      accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight','gradBias'}, {'diagHessianWeight','diagHessianBias'}) -- bias too: initDiagHessianParameters below allocates diagHessianBias
+   end
+
+   function nn.SpatialFullConvolution.initDiagHessianParameters(self)
+      initDiagHessianParameters(self,{'gradWeight','gradBias'},{'diagHessianWeight','diagHessianBias'})
+   end
+
+   ----------------------------------------------------------------------
    -- SpatialConvolutionMap
    ----------------------------------------------------------------------
    function nn.SpatialConvolutionMap.updateDiagHessianInput(self, input, diagHessianOutput)
@@ -212,6 +246,22 @@ function nn.hessian.enable()
    end
 
    ----------------------------------------------------------------------
+   -- SpatialFullConvolutionMap
+   ----------------------------------------------------------------------
+   function nn.SpatialFullConvolutionMap.updateDiagHessianInput(self, input, diagHessianOutput)
+      updateDiagHessianInput(self, input, diagHessianOutput, {'weight'}, {'weightSq'})
+      return self.diagHessianInput
+   end
+
+   function nn.SpatialFullConvolutionMap.accDiagHessianParameters(self, input, diagHessianOutput)
+      accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight','gradBias'}, {'diagHessianWeight','diagHessianBias'})
+   end
+
+   function nn.SpatialFullConvolutionMap.initDiagHessianParameters(self)
+      initDiagHessianParameters(self,{'gradWeight','gradBias'},{'diagHessianWeight','diagHessianBias'})
+   end
+
+----------------------------------------------------------------------
    -- Tanh
    ----------------------------------------------------------------------
    function nn.Tanh.updateDiagHessianInput(self, input, diagHessianOutput)
@@ -220,6 +270,16 @@ function nn.hessian.enable()
    end
 
    ----------------------------------------------------------------------
+   -- TanhShrink
+   ----------------------------------------------------------------------
+   function nn.TanhShrink.updateDiagHessianInput(self, input, diagHessianOutput)
+      updateDiagHessianInputPointWise(self.tanh, input, diagHessianOutput) -- fills self.tanh.diagHessianInput (read below)
+      self.diagHessianInput = self.diagHessianInput or input.new():resizeAs(input)
+      torch.add(self.diagHessianInput, self.tanh.diagHessianInput, diagHessianOutput) -- result = inner Tanh term + diagHessianOutput
+      return self.diagHessianInput
+   end
+
+   ----------------------------------------------------------------------
    -- Square
    ----------------------------------------------------------------------
    function nn.Square.updateDiagHessianInput(self, input, diagHessianOutput)