diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-14 18:55:07 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-14 18:55:07 +0400 |
commit | 7f0f53416f019bb980c9e922b72048ab5ea4fd77 (patch) | |
tree | c98565dee570bb097410b47102173db07635938a /MixtureTable.lua | |
parent | f0a5f6c8926af4d3c068f1ffbf7290c3e7b21068 (diff) |
MixtureTable type-cast unit tests
Diffstat (limited to 'MixtureTable.lua')
-rw-r--r-- | MixtureTable.lua | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/MixtureTable.lua b/MixtureTable.lua index b107618..6111a99 100644 --- a/MixtureTable.lua +++ b/MixtureTable.lua @@ -150,11 +150,6 @@ end function MixtureTable:type(type) self.output = self.output:type(type) self.gradInput[1] = self.gradInput[1]:type(type) - if torch.type(self.gradInput[2]) == 'table' then - for i,expertGradInput in ipairs(self.gradInput[2]) do - self.gradInput[2][i] = expertGradInput:type(type) - end - end self._gaterView = self._gaterView:type(type) self._expert = self._expert:type(type) self._expertView = self._expertView:type(type) @@ -162,4 +157,11 @@ function MixtureTable:type(type) self._gradInput = self._gradInput:type(type) self._expert2 = self._expert2:type(type) self._expertView2 = self._expertView2:type(type) + if torch.type(self.gradInput[2]) == 'table' then + for i,expertGradInput in ipairs(self.gradInput[2]) do + self.gradInput[2][i] = expertGradInput:type(type) + end + else + self.gradInput[2] = self._gradInput + end end |