diff options
author | Nicholas Leonard <nleonard@twitter.com> | 2017-05-26 21:33:15 +0300 |
---|---|---|
committer | Nicholas Leonard <nleonard@twitter.com> | 2017-05-26 21:33:15 +0300 |
commit | 0863b44e61312292205f8fdf4f1b08cb282bf1f4 (patch) | |
tree | 68875e902c0c68894a7425411feeaf1e3a7280d6 | |
parent | df1af9500a45f4deecd0f3f1f5020fe4789248ca (diff) |
nn.ModuleCriterion
-rw-r--r-- | ModuleCriterion.lua | 44 | ||||
-rw-r--r-- | doc/criterion.md | 15 | ||||
-rwxr-xr-x | init.lua | 1 | ||||
-rwxr-xr-x | test.lua | 18 |
4 files changed, 78 insertions, 0 deletions
local ModuleCriterion, parent = torch.class("nn.ModuleCriterion", "nn.Criterion")

--- Decorates a criterion so the `input` (and optionally the `target`) are fed
-- through a module before reaching the decorated criterion.
-- @param criterion    the nn.Criterion to decorate
-- @param inputModule  optional nn.Module applied to the input; must be
--                     parameterless, as its parameters would never be updated
-- @param targetModule optional nn.Module applied to the target
-- @param castTarget   when true (the default), targetModule is cast by :type()
--                     along with inputModule and criterion
function ModuleCriterion:__init(criterion, inputModule, targetModule, castTarget)
   -- initialize self.output / self.gradInput like every other nn.Criterion
   parent.__init(self)
   self.inputModule = inputModule
   self.targetModule = targetModule
   -- explicit nil-check: `castTarget or true` would wrongly discard `false`
   self.castTarget = (castTarget == nil) and true or castTarget
   if self.inputModule then
      local params = self.inputModule:parameters()
      if params and #params > 0 then
         print"Warning: nn.ModuleCriterion doesn't support parameter updates"
      end
   end
   self.criterion = criterion
end

function ModuleCriterion:updateOutput(input, target)
   -- transform input/target through the decorating modules, when present
   if self.inputModule then
      self.input = self.inputModule:forward(input)
   end
   if self.targetModule then
      self.target = self.targetModule:forward(target)
   end
   -- fall back to the raw input/target when no decorating module is set
   self.output = self.criterion:forward(self.input or input, self.target or target)
   return self.output
end

function ModuleCriterion:updateGradInput(input, target)
   self.gradInput = self.criterion:backward(self.input or input, self.target or target)
   if self.inputModule then
      -- backpropagate through the input transformation (gradients w.r.t.
      -- parameters are never accumulated; see the warning in __init)
      self.gradInput = self.inputModule:backward(input, self.gradInput)
   end
   return self.gradInput
end

function ModuleCriterion:type(type, typecache)
   if self.inputModule then
      self.inputModule:type(type, typecache)
   end
   -- the target module is only cast when requested at construction time
   if self.castTarget and self.targetModule then
      self.targetModule:type(type, typecache)
   end
   self.criterion:type(type, typecache)
   return parent.type(self, type, typecache)
end
 * [`MultiCriterion`](#nn.MultiCriterion) : a weighted sum of other criterions each applied to the same input and target;
 * [`ParallelCriterion`](#nn.ParallelCriterion) : a weighted sum of other criterions each applied to a different input and target;
 * [`MarginRankingCriterion`](#nn.MarginRankingCriterion): ranks two inputs;
 * [`ModuleCriterion`](#nn.ModuleCriterion) : adds an optional `inputModule` and `targetModule` before a decorated criterion;

<a name="nn.Criterion"></a>
## Criterion ##

<a name='nn.ModuleCriterion'></a>
## ModuleCriterion ##

```lua
criterion = nn.ModuleCriterion(criterion [, inputModule, targetModule, castTarget])
```

This criterion decorates a `criterion` by allowing the `input` and `target` to be
fed through an optional `inputModule` and `targetModule` before being passed to the
`criterion`. The `inputModule` must not contain parameters, as these would never be updated.

When `castTarget = true` (the default), the `targetModule` is cast along with the `inputModule` and
`criterion` whenever the criterion's type is changed. Otherwise, the `targetModule` keeps its current type.
-- init.lua: register the new criterion with the nn package
require('nn.ModuleCriterion')

-- test.lua: checks that nn.ModuleCriterion(criterion, inputModule) behaves
-- exactly like manually feeding the input through inputModule before the
-- criterion, for both the forward loss and the backward gradient.
function nntest.ModuleCriterion()
   local input = torch.randn(8,4)
   local target = torch.randn(8,4)
   local inputModule = nn.Tanh()
   local criterion = nn.MSECriterion()
   local mc = nn.ModuleCriterion(criterion, inputModule)

   local err = mc:forward(input, target)
   local gradInput = mc:backward(input, target)

   -- reference computation without the decorator
   local output = inputModule:forward(input)
   local err2 = criterion:forward(output, target)
   local gradOutput = criterion:backward(output, target)
   local gradInput2 = inputModule:backward(input, gradOutput)
   -- was "backward err" in both messages (copy-paste); the first assertion
   -- checks the forward loss, so its failure message must say so
   mytester:assert(err == err2, "ModuleCriterion forward err")
   mytester:assertTensorEq(gradInput, gradInput2, 0.000001, "ModuleCriterion backward err")
end