diff options
author | GaetanMarceauCaron <gaetan.marceau-caron@inria.fr> | 2016-04-15 17:37:19 +0300 |
---|---|---|
committer | GaetanMarceauCaron <gaetan.marceau-caron@inria.fr> | 2016-04-15 17:37:19 +0300 |
commit | 3b669b13d31cae16ea7a61d7eb1d1e7b8fb35e1c (patch) | |
tree | f32c84fbc7bf59a68c53922568808f1fcb7a0b25 | |
parent | f9cd545edb06bede2a3f5b98987a65fcad777c81 (diff) |
Removing useless flags for OP metric
-rw-r--r-- | QDRiemaNNLinear.lua | 57 |
1 files changed, 23 insertions, 34 deletions
diff --git a/QDRiemaNNLinear.lua b/QDRiemaNNLinear.lua index 8dc4793..961a467 100644 --- a/QDRiemaNNLinear.lua +++ b/QDRiemaNNLinear.lua @@ -18,46 +18,35 @@ function QDRiemaNNLinear:__init(inputSize, outputSize, gamma, qdFlag) self.Mii = torch.Tensor(outputSize, inputSize) if self.qdFlag then self.M0i = torch.Tensor(outputSize, inputSize) end self.M00 = torch.Tensor(outputSize) - self.accGradientFlag = true - self.accMetricFlag = true -end - -function QDRiemaNNLinear:setAccFlag(accGradientFlag,accMetricFlag) - self.accGradientFlag = accGradientFlag - self.accMetricFlag = accMetricFlag end function QDRiemaNNLinear:accGradParameters(input, gradOutput) - if self.accGradientFlag then - parent.accGradParameters(self,input,gradOutput) - end + parent.accGradParameters(self,input,gradOutput) - if self.accMetricFlag then - local gradOutputSqT = torch.pow(gradOutput,2):t() + local gradOutputSqT = torch.pow(gradOutput,2):t() + + if self.initMetric then + self.Mii:mm(gradOutputSqT,torch.pow(input,2)) + self.M00:mv(gradOutputSqT,self.addBuffer) + if self.qdFlag then self.M0i:mm(gradOutputSqT,input) end + self.initMetric = false + else + self.Mii:mul(1.-self.gamma):addmm(self.gamma,gradOutputSqT,torch.pow(input,2)) + if self.qdFlag then self.M0i:mul(1.-self.gamma):addmm(self.gamma,gradOutputSqT,input) end + self.M00:mul(1.-self.gamma):addmv(self.gamma,gradOutputSqT,self.addBuffer) + end + + if self.qdFlag then + local numerator = torch.add(torch.cmul(self.gradWeight,self.M00:view(-1,1):expandAs(self.gradWeight)), -1.0, torch.cmul(self.M0i,self.gradBias:view(-1,1):expandAs(self.M0i))) + local denominator = torch.add(torch.cmul(self.Mii,self.M00:view(-1,1):expandAs(self.Mii)),-1.0,torch.pow(self.M0i,2)):clamp(self.matReg,1e25) + self.gradWeight:copy(numerator:cdiv(denominator)) - if self.initMetric then - self.Mii:mm(gradOutputSqT,torch.pow(input,2)) - self.M00:mv(gradOutputSqT,self.addBuffer) - if self.qdFlag then self.M0i:mm(gradOutputSqT,input) end - self.initMetric = false - else - self.Mii:mul(1.-self.gamma):addmm(self.gamma,gradOutputSqT,torch.pow(input,2)) - if self.qdFlag then self.M0i:mul(1.-self.gamma):addmm(self.gamma,gradOutputSqT,input) end - self.M00:mul(1.-self.gamma):addmv(self.gamma,gradOutputSqT,self.addBuffer) - end + local temp = torch.cmul(self.M0i,self.gradWeight):sum(2) + self.gradBias:add(-1.,temp):cdiv(torch.add(self.M00,self.matReg)) - if self.qdFlag then - local numerator = torch.add(torch.cmul(self.gradWeight,self.M00:view(-1,1):expandAs(self.gradWeight)), -1.0, torch.cmul(self.M0i,self.gradBias:view(-1,1):expandAs(self.M0i))) - local denominator = torch.add(torch.cmul(self.Mii,self.M00:view(-1,1):expandAs(self.Mii)),-1.0,torch.pow(self.M0i,2)):clamp(self.matReg,1e25) - self.gradWeight:copy(numerator:cdiv(denominator)) - - local temp = torch.cmul(self.M0i,self.gradWeight):sum(2) - self.gradBias:add(-1.,temp):cdiv(torch.add(self.M00,self.matReg)) - - else - self.gradWeight:cdiv(self.Mii:add(self.matReg)) - self.gradBias:cdiv(self.M00:add(self.matReg)) - end + else + self.gradWeight:cdiv(self.Mii:add(self.matReg)) + self.gradBias:cdiv(self.M00:add(self.matReg)) end end |