Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-14 18:55:07 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-14 18:55:07 +0400
commit7f0f53416f019bb980c9e922b72048ab5ea4fd77 (patch)
treec98565dee570bb097410b47102173db07635938a /MixtureTable.lua
parentf0a5f6c8926af4d3c068f1ffbf7290c3e7b21068 (diff)
MixtureTable type-cast unit tests
Diffstat (limited to 'MixtureTable.lua')
-rw-r--r--MixtureTable.lua12
1 files changed, 7 insertions, 5 deletions
diff --git a/MixtureTable.lua b/MixtureTable.lua
index b107618..6111a99 100644
--- a/MixtureTable.lua
+++ b/MixtureTable.lua
@@ -150,11 +150,6 @@ end
function MixtureTable:type(type)
self.output = self.output:type(type)
self.gradInput[1] = self.gradInput[1]:type(type)
- if torch.type(self.gradInput[2]) == 'table' then
- for i,expertGradInput in ipairs(self.gradInput[2]) do
- self.gradInput[2][i] = expertGradInput:type(type)
- end
- end
self._gaterView = self._gaterView:type(type)
self._expert = self._expert:type(type)
self._expertView = self._expertView:type(type)
@@ -162,4 +157,11 @@ function MixtureTable:type(type)
self._gradInput = self._gradInput:type(type)
self._expert2 = self._expert2:type(type)
self._expertView2 = self._expertView2:type(type)
+ if torch.type(self.gradInput[2]) == 'table' then
+ for i,expertGradInput in ipairs(self.gradInput[2]) do
+ self.gradInput[2][i] = expertGradInput:type(type)
+ end
+ else
+ self.gradInput[2] = self._gradInput
+ end
end