github.com/clementfarabet/lua---nnx.git
author    GaetanMarceauCaron <gaetan.marceau-caron@inria.fr>  2016-04-15 17:37:19 +0300
committer GaetanMarceauCaron <gaetan.marceau-caron@inria.fr>  2016-04-15 17:37:19 +0300
commit    3b669b13d31cae16ea7a61d7eb1d1e7b8fb35e1c
tree      f32c84fbc7bf59a68c53922568808f1fcb7a0b25
parent    f9cd545edb06bede2a3f5b98987a65fcad777c81
Removing useless flags for OP metric
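
In short, the commit drops the accGradientFlag/accMetricFlag pair and the setAccFlag method, so accGradParameters now always accumulates both the parameter gradient and the quasi-diagonal metric statistics ("OP" presumably refers to the outer-product metric this layer implements). A rough sketch of what that means at a call site, assuming nnx registers the layer under the nn namespace as nn.QDRiemaNNLinear; the surrounding training snippet and all numeric values are purely illustrative:

-- Illustrative only: a hypothetical call site showing the effect of this change.
-- Before this commit, accumulation could be toggled per phase, e.g.
--    layer:setAccFlag(false, true)
-- That method is removed; accGradParameters now always updates gradient and metric.
require 'nn'
require 'nnx'   -- assumption: provides nn.QDRiemaNNLinear

local layer = nn.QDRiemaNNLinear(10, 5, 0.01, true)  -- inputSize, outputSize, gamma, qdFlag (made-up values)
local criterion = nn.MSECriterion()
local input, target = torch.randn(8, 10), torch.randn(8, 5)

local output = layer:forward(input)
criterion:forward(output, target)
layer:zeroGradParameters()
layer:backward(input, criterion:backward(output, target))  -- accGradParameters runs here; no flags involved
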
-rw-r--r--  QDRiemaNNLinear.lua  57
1 file changed, 23 insertions, 34 deletions
diff --git a/QDRiemaNNLinear.lua b/QDRiemaNNLinear.lua
index 8dc4793..961a467 100644
--- a/QDRiemaNNLinear.lua
+++ b/QDRiemaNNLinear.lua
@@ -18,46 +18,35 @@ function QDRiemaNNLinear:__init(inputSize, outputSize, gamma, qdFlag)
    self.Mii = torch.Tensor(outputSize, inputSize)
    if self.qdFlag then self.M0i = torch.Tensor(outputSize, inputSize) end
    self.M00 = torch.Tensor(outputSize)
-   self.accGradientFlag = true
-   self.accMetricFlag = true
-end
-
-function QDRiemaNNLinear:setAccFlag(accGradientFlag,accMetricFlag)
-   self.accGradientFlag = accGradientFlag
-   self.accMetricFlag = accMetricFlag
 end

 function QDRiemaNNLinear:accGradParameters(input, gradOutput)
-   if self.accGradientFlag then
-      parent.accGradParameters(self,input,gradOutput)
-   end
+   parent.accGradParameters(self,input,gradOutput)
-   if self.accMetricFlag then
-      local gradOutputSqT = torch.pow(gradOutput,2):t()
+   local gradOutputSqT = torch.pow(gradOutput,2):t()
+
+   if self.initMetric then
+      self.Mii:mm(gradOutputSqT,torch.pow(input,2))
+      self.M00:mv(gradOutputSqT,self.addBuffer)
+      if self.qdFlag then self.M0i:mm(gradOutputSqT,input) end
+      self.initMetric = false
+   else
+      self.Mii:mul(1.-self.gamma):addmm(self.gamma,gradOutputSqT,torch.pow(input,2))
+      if self.qdFlag then self.M0i:mul(1.-self.gamma):addmm(self.gamma,gradOutputSqT,input) end
+      self.M00:mul(1.-self.gamma):addmv(self.gamma,gradOutputSqT,self.addBuffer)
+   end
+
+   if self.qdFlag then
+      local numerator = torch.add(torch.cmul(self.gradWeight,self.M00:view(-1,1):expandAs(self.gradWeight)), -1.0, torch.cmul(self.M0i,self.gradBias:view(-1,1):expandAs(self.M0i)))
+      local denominator = torch.add(torch.cmul(self.Mii,self.M00:view(-1,1):expandAs(self.Mii)),-1.0,torch.pow(self.M0i,2)):clamp(self.matReg,1e25)
+      self.gradWeight:copy(numerator:cdiv(denominator))
-      if self.initMetric then
-         self.Mii:mm(gradOutputSqT,torch.pow(input,2))
-         self.M00:mv(gradOutputSqT,self.addBuffer)
-         if self.qdFlag then self.M0i:mm(gradOutputSqT,input) end
-         self.initMetric = false
-      else
-         self.Mii:mul(1.-self.gamma):addmm(self.gamma,gradOutputSqT,torch.pow(input,2))
-         if self.qdFlag then self.M0i:mul(1.-self.gamma):addmm(self.gamma,gradOutputSqT,input) end
-         self.M00:mul(1.-self.gamma):addmv(self.gamma,gradOutputSqT,self.addBuffer)
-      end
+      local temp = torch.cmul(self.M0i,self.gradWeight):sum(2)
+      self.gradBias:add(-1.,temp):cdiv(torch.add(self.M00,self.matReg))
-      if self.qdFlag then
-         local numerator = torch.add(torch.cmul(self.gradWeight,self.M00:view(-1,1):expandAs(self.gradWeight)), -1.0, torch.cmul(self.M0i,self.gradBias:view(-1,1):expandAs(self.M0i)))
-         local denominator = torch.add(torch.cmul(self.Mii,self.M00:view(-1,1):expandAs(self.Mii)),-1.0,torch.pow(self.M0i,2)):clamp(self.matReg,1e25)
-         self.gradWeight:copy(numerator:cdiv(denominator))
-
-         local temp = torch.cmul(self.M0i,self.gradWeight):sum(2)
-         self.gradBias:add(-1.,temp):cdiv(torch.add(self.M00,self.matReg))
-
-      else
-         self.gradWeight:cdiv(self.Mii:add(self.matReg))
-         self.gradBias:cdiv(self.M00:add(self.matReg))
-      end
+   else
+      self.gradWeight:cdiv(self.Mii:add(self.matReg))
+      self.gradBias:cdiv(self.M00:add(self.matReg))
    end
 end
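
For readers skimming the diff: the surviving logic is an exponential moving average of the quasi-diagonal metric terms, reset on the first call via initMetric. Below is a minimal standalone sketch of that update pattern with plain torch tensors; the shapes, the gamma value, and the variable names are illustrative stand-ins, not the module's API:

-- Minimal sketch of the metric accumulation pattern above (illustrative shapes).
require 'torch'

local batch, inputSize, outputSize = 8, 10, 5
local gamma = 0.01                               -- plays the role of self.gamma
local Mii = torch.Tensor(outputSize, inputSize)  -- per-weight metric term
local M00 = torch.Tensor(outputSize)             -- per-output bias term
local initMetric = true

-- stand-ins for the layer's batched input and gradOutput
local input = torch.randn(batch, inputSize)
local gradOutput = torch.randn(batch, outputSize)
local addBuffer = torch.ones(batch)              -- same role as nn.Linear's addBuffer

local gradOutputSqT = torch.pow(gradOutput, 2):t()
if initMetric then
   -- first call: overwrite the accumulators
   Mii:mm(gradOutputSqT, torch.pow(input, 2))
   M00:mv(gradOutputSqT, addBuffer)
   initMetric = false
else
   -- later calls: exponential moving average with rate gamma
   Mii:mul(1 - gamma):addmm(gamma, gradOutputSqT, torch.pow(input, 2))
   M00:mul(1 - gamma):addmv(gamma, gradOutputSqT, addBuffer)
end

With qdFlag enabled, the same pattern additionally maintains the cross term M0i, which the quasi-diagonal correction of gradWeight and gradBias in the diff then divides through.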