diff options
Diffstat (limited to 'src/layers/common.cc')
-rw-r--r-- | src/layers/common.cc | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/src/layers/common.cc b/src/layers/common.cc index e3d7f22d..75ba0c12 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -365,5 +365,31 @@ namespace ctranslate2 { _norm_op(_beta, _gamma, input, output); } + + Conv1D::Conv1D(const models::Model& model, + const std::string& scope, + dim_t stride, + dim_t padding, + dim_t dilation) + : _conv_op(stride, padding, dilation) + , _weight(model.get_variable(scope + "/weight")) + , _bias(model.get_variable_if_exists(scope + "/bias")) { + } + + DataType Conv1D::output_type() const { + return _weight.dtype(); + } + + dim_t Conv1D::output_size() const { + return _weight.dim(0); + } + + void Conv1D::operator()(const StorageView& input, StorageView& output) const { + if (_bias) + _conv_op(input, _weight, *_bias, output); + else + _conv_op(input, _weight, output); + } + } } |