Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nleonard@twitter.com>2017-05-26 21:33:15 +0300
committerNicholas Leonard <nleonard@twitter.com>2017-05-26 21:33:15 +0300
commit0863b44e61312292205f8fdf4f1b08cb282bf1f4 (patch)
tree68875e902c0c68894a7425411feeaf1e3a7280d6
parentdf1af9500a45f4deecd0f3f1f5020fe4789248ca (diff)
nn.ModuleCriterion
-rw-r--r--ModuleCriterion.lua44
-rw-r--r--doc/criterion.md15
-rwxr-xr-xinit.lua1
-rwxr-xr-xtest.lua18
4 files changed, 78 insertions, 0 deletions
diff --git a/ModuleCriterion.lua b/ModuleCriterion.lua
new file mode 100644
index 0000000..bfc79ef
--- /dev/null
+++ b/ModuleCriterion.lua
@@ -0,0 +1,44 @@
+local ModuleCriterion, parent = torch.class("nn.ModuleCriterion", "nn.Criterion")
+
+function ModuleCriterion:__init(criterion, inputModule, targetModule, castTarget)
+ self.inputModule = inputModule
+ self.targetModule = targetModule
+ self.castTarget = (castTarget == nil) and true or castTarget
+ if self.inputModule then
+ local params = self.inputModule:parameters()
+ if params and #params > 0 then
+ print"Warning: nn.ModuleCriterion doesn't support parameter updates"
+ end
+ end
+ self.criterion = criterion
+end
+
+function ModuleCriterion:updateOutput(input, target)
+ if self.inputModule then
+ self.input = self.inputModule:forward(input)
+ end
+ if self.targetModule then
+ self.target = self.targetModule:forward(target)
+ end
+ self.output = self.criterion:forward(self.input or input, self.target or target)
+ return self.output
+end
+
+function ModuleCriterion:updateGradInput(input, target)
+ self.gradInput = self.criterion:backward(self.input or input, self.target or target)
+ if self.inputModule then
+ self.gradInput = self.inputModule:backward(input, self.gradInput)
+ end
+ return self.gradInput
+end
+
+function ModuleCriterion:type(type, typecache)
+ if self.inputModule then
+ self.inputModule:type(type, typecache)
+ end
+ if self.castTarget and self.targetModule then
+ self.targetModule:type(type, typecache)
+ end
+ self.criterion:type(type, typecache)
+ return parent.type(self, type, typecache)
+end
diff --git a/doc/criterion.md b/doc/criterion.md
index a3e1b2e..000fcb7 100644
--- a/doc/criterion.md
+++ b/doc/criterion.md
@@ -29,6 +29,7 @@ target, they compute a gradient according to a given loss function.
* [`MultiCriterion`](#nn.MultiCriterion) : a weighted sum of other criterions each applied to the same input and target;
* [`ParallelCriterion`](#nn.ParallelCriterion) : a weighted sum of other criterions each applied to a different input and target;
* [`MarginRankingCriterion`](#nn.MarginRankingCriterion): ranks two inputs;
+ * [`ModuleCriterion`](#nn.ModuleCriterion) : adds an optional `inputModule` and `targetModule` before a decorated criterion;
<a name="nn.Criterion"></a>
## Criterion ##
@@ -877,3 +878,17 @@ for i = 1, 100 do
end
end
```
+
+<a name='nn.ModuleCriterion'></a>
+## ModuleCriterion ##
+
+```lua
+criterion = nn.ModuleCriterion(criterion [, inputModule, targetModule, castTarget])
+```
+
+This criterion decorates a `criterion` by allowing the `input` and `target` to be
+fed through an optional `inputModule` and `targetModule` before being passed to the
+`criterion`. The `inputModule` must not contain parameters as these would not be updated.
+
+When `castTarget = true` (the default), the `targetModule` is cast along with the `inputModule` and
+`criterion`. Otherwise, the `targetModule` isn't.
diff --git a/init.lua b/init.lua
index 503d2c2..a48a9b5 100755
--- a/init.lua
+++ b/init.lua
@@ -201,6 +201,7 @@ require('nn.BCECriterion')
require('nn.CrossEntropyCriterion')
require('nn.ParallelCriterion')
require('nn.DistanceRatioCriterion')
+require('nn.ModuleCriterion')
require('nn.PixelShuffle')
diff --git a/test.lua b/test.lua
index 67b9fd9..465ce8e 100755
--- a/test.lua
+++ b/test.lua
@@ -2175,7 +2175,25 @@ function nntest.MarginRankingCriterion()
local v = torch.rand(2, batch_size)
local t = torch.Tensor(batch_size):random(0,1):mul(2):add(-1)
criterionJacobianTest1DTable(crit,v,t)
+end
+
+function nntest.ModuleCriterion()
+ local input = torch.randn(8,4)
+ local target = torch.randn(8,4)
+ local inputModule = nn.Tanh()
+ local criterion = nn.MSECriterion()
+ local mc = nn.ModuleCriterion(criterion, inputModule)
+
+ local err = mc:forward(input, target)
+ local gradInput = mc:backward(input, target)
+
+ local output = inputModule:forward(input)
+ local err2 = criterion:forward(output, target)
+ local gradOutput = criterion:backward(output, target)
+ local gradInput2 = inputModule:backward(input, gradOutput)
+ mytester:assert(err == err2, "ModuleCriterion backward err")
+ mytester:assertTensorEq(gradInput, gradInput2, 0.000001, "ModuleCriterion backward err")
end
function nntest.MaskedSelect()