Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/layers/common.cc')
-rw-r--r--src/layers/common.cc26
1 files changed, 26 insertions, 0 deletions
diff --git a/src/layers/common.cc b/src/layers/common.cc
index e3d7f22d..75ba0c12 100644
--- a/src/layers/common.cc
+++ b/src/layers/common.cc
@@ -365,5 +365,31 @@ namespace ctranslate2 {
_norm_op(_beta, _gamma, input, output);
}
+
+ Conv1D::Conv1D(const models::Model& model,
+ const std::string& scope,
+ dim_t stride,
+ dim_t padding,
+ dim_t dilation)
+ : _conv_op(stride, padding, dilation)
+ , _weight(model.get_variable(scope + "/weight"))
+ , _bias(model.get_variable_if_exists(scope + "/bias")) {
+ }
+
+ DataType Conv1D::output_type() const {
+ return _weight.dtype();
+ }
+
+ dim_t Conv1D::output_size() const {
+ return _weight.dim(0);
+ }
+
+ void Conv1D::operator()(const StorageView& input, StorageView& output) const {
+ if (_bias)
+ _conv_op(input, _weight, *_bias, output);
+ else
+ _conv_op(input, _weight, output);
+ }
+
}
}