Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'hessian.lua')
-rw-r--r--hessian.lua62
1 files changed, 61 insertions, 1 deletions
diff --git a/hessian.lua b/hessian.lua
index c55e066..3d336fe 100644
--- a/hessian.lua
+++ b/hessian.lua
@@ -164,6 +164,24 @@ function nn.hessian.enable()
end
----------------------------------------------------------------------
+ -- WeightedMSECriterion
+ ----------------------------------------------------------------------
+ function nn.WeightedMSECriterion.updateDiagHessianInput(self,input,target)
+ return nn.MSECriterion.updateDiagHessianInput(self,input,target)
+ end
+
+ ----------------------------------------------------------------------
+ -- L1Cost
+ ----------------------------------------------------------------------
+ function nn.L1Cost.updateDiagHessianInput(self,input)
+ self.diagHessianInput = self.diagHessianInput or input.new()
+ self.diagHessianInput:resizeAs(input)
+ self.diagHessianInput:fill(1)
+ self.diagHessianInput[torch.eq(input,0)] = 0
+ return self.diagHessianInput
+ end
+
+ ----------------------------------------------------------------------
-- Linear
----------------------------------------------------------------------
function nn.Linear.updateDiagHessianInput(self, input, diagHessianOutput)
@@ -188,7 +206,7 @@ function nn.hessian.enable()
end
function nn.SpatialConvolution.accDiagHessianParameters(self, input, diagHessianOutput)
- accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight'}, {'diagHessianWeight'})
+ accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight','gradBias'}, {'diagHessianWeight','diagHessianBias'})
end
function nn.SpatialConvolution.initDiagHessianParameters(self)
@@ -196,6 +214,22 @@ function nn.hessian.enable()
end
----------------------------------------------------------------------
+ -- SpatialFullConvolution
+ ----------------------------------------------------------------------
+ function nn.SpatialFullConvolution.updateDiagHessianInput(self, input, diagHessianOutput)
+ updateDiagHessianInput(self, input, diagHessianOutput, {'weight'}, {'weightSq'})
+ return self.diagHessianInput
+ end
+
+ function nn.SpatialFullConvolution.accDiagHessianParameters(self, input, diagHessianOutput)
+ accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight'}, {'diagHessianWeight'})
+ end
+
+ function nn.SpatialFullConvolution.initDiagHessianParameters(self)
+ initDiagHessianParameters(self,{'gradWeight','gradBias'},{'diagHessianWeight','diagHessianBias'})
+ end
+
+ ----------------------------------------------------------------------
-- SpatialConvolutionMap
----------------------------------------------------------------------
function nn.SpatialConvolutionMap.updateDiagHessianInput(self, input, diagHessianOutput)
@@ -212,6 +246,22 @@ function nn.hessian.enable()
end
----------------------------------------------------------------------
+ -- SpatialFullConvolutionMap
+ ----------------------------------------------------------------------
+ function nn.SpatialFullConvolutionMap.updateDiagHessianInput(self, input, diagHessianOutput)
+ updateDiagHessianInput(self, input, diagHessianOutput, {'weight'}, {'weightSq'})
+ return self.diagHessianInput
+ end
+
+ function nn.SpatialFullConvolutionMap.accDiagHessianParameters(self, input, diagHessianOutput)
+ accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight','gradBias'}, {'diagHessianWeight','diagHessianBias'})
+ end
+
+ function nn.SpatialFullConvolutionMap.initDiagHessianParameters(self)
+ initDiagHessianParameters(self,{'gradWeight','gradBias'},{'diagHessianWeight','diagHessianBias'})
+ end
+
+----------------------------------------------------------------------
-- Tanh
----------------------------------------------------------------------
function nn.Tanh.updateDiagHessianInput(self, input, diagHessianOutput)
@@ -220,6 +270,16 @@ function nn.hessian.enable()
end
----------------------------------------------------------------------
+ -- TanhShrink
+ ----------------------------------------------------------------------
+ function nn.TanhShrink.updateDiagHessianInput(self, input, diagHessianOutput)
+ updateDiagHessianInputPointWise(self.tanh, input, diagHessianOutput)
+ self.diagHessianInput = self.diagHessianInput or input.new():resizeAs(input)
+ torch.add(self.diagHessianInput, self.tanh.diagHessianInput, diagHessianOutput)
+ return self.diagHessianInput
+ end
+
+ ----------------------------------------------------------------------
-- Square
----------------------------------------------------------------------
function nn.Square.updateDiagHessianInput(self, input, diagHessianOutput)