Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- CMakeLists.txt | 2
-rwxr-xr-x scripts/amunmt_server.py | 49
-rw-r--r-- src/3rd_party/CMakeLists.txt | 2
-rw-r--r-- src/CMakeLists.txt | 106
-rw-r--r-- src/common/god.cpp | 10
-rw-r--r-- src/python/amunmt.cpp | 73
-rw-r--r-- src/python/test.py | 13
7 files changed, 198 insertions, 57 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b90df98c..587691d2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,7 +25,7 @@ else(CUDA_FOUND)
endif(CUDA_FOUND)
endif(NOCUDA)
-find_package(Boost COMPONENTS system filesystem program_options timer iostreams thread python)
+find_package(Boost COMPONENTS system filesystem program_options timer iostreams python thread)
if(Boost_FOUND)
include_directories(${Boost_INCLUDE_DIRS})
set(EXT_LIBS ${EXT_LIBS} ${Boost_LIBRARIES})
diff --git a/scripts/amunmt_server.py b/scripts/amunmt_server.py
new file mode 100755
index 00000000..b82e5a48
--- /dev/null
+++ b/scripts/amunmt_server.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import sys
+import os
+import argparse
+
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../build/src')
+import libamunmt as nmt
+
+from bottle import request, Bottle, abort
+
+app = Bottle()
+
+
+@app.route('/translate')
+def handle_websocket():
+ wsock = request.environ.get('wsgi.websocket')
+ if not wsock:
+ abort(400, 'Expected WebSocket request.')
+
+ while True:
+ try:
+ message = wsock.receive()
+ if message is not None:
+ trans = nmt.translate(message.split('\n'))
+ wsock.send('\n'.join(trans))
+ except WebSocketError:
+ break
+
+
+def parse_args():
+ """ parse command arguments """
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-c", dest="config")
+ parser.add_argument('-p', dest="port", default=8080, type=int)
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ nmt.init("-c {}".format(args.config))
+
+ from gevent.pywsgi import WSGIServer
+ from geventwebsocket import WebSocketError
+ from geventwebsocket.handler import WebSocketHandler
+ server = WSGIServer(("0.0.0.0", args.port), app,
+ handler_class=WebSocketHandler)
+ server.serve_forever()
diff --git a/src/3rd_party/CMakeLists.txt b/src/3rd_party/CMakeLists.txt
index 9601fb95..9130d20d 100644
--- a/src/3rd_party/CMakeLists.txt
+++ b/src/3rd_party/CMakeLists.txt
@@ -3,6 +3,6 @@ include_directories(.)
add_subdirectory(yaml-cpp)
-add_library(libcommon OBJECT
+add_library(libcnpy OBJECT
cnpy/cnpy.cpp
)
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index f30e2a95..fb192c58 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -2,40 +2,53 @@
include_directories(.)
include_directories(3rd_party)
-add_subdirectory(3rd_party)
-
-
-
-#add_library(librescorer OBJECT
-# rescorer/nbest.cpp
-#)
-
+add_library(cpumode OBJECT
+ cpu/mblas/matrix.cpp
+ cpu/mblas/phoenix_functions.cpp
+ cpu/dl4mt/decoder.cpp
+ cpu/dl4mt/encoder.cpp
+ cpu/dl4mt/gru.cpp
+ cpu/dl4mt/model.cpp
+ cpu/decoder/encoder_decoder.cpp
+)
-if(CUDA_FOUND)
-cuda_add_executable(
- amun
- common/decoder_main.cpp
+add_library(libcommon OBJECT
common/config.cpp
common/exception.cpp
- common/loader_factory.cpp
- common/logging.cpp
- common/vocab.cpp
- common/utils.cpp
common/god.cpp
common/history.cpp
common/loader.cpp
+ common/loader_factory.cpp
+ common/logging.cpp
common/printer.cpp
common/scorer.cpp
common/search.cpp
common/sentence.cpp
common/processor/bpe.cpp
- cpu/mblas/matrix.cpp
- cpu/mblas/phoenix_functions.cpp
- cpu/dl4mt/decoder.cpp
- cpu/dl4mt/encoder.cpp
- cpu/dl4mt/gru.cpp
- cpu/dl4mt/model.cpp
- cpu/decoder/encoder_decoder.cpp
+ common/utils.cpp
+ common/vocab.cpp
+)
+
+if(CUDA_FOUND)
+
+cuda_add_executable(
+ amun
+ common/decoder_main.cpp
+ common/loader_factory.cu
+ gpu/decoder/ape_penalty.cu
+ gpu/decoder/encoder_decoder.cu
+ gpu/dl4mt/encoder.cu
+ gpu/dl4mt/gru.cu
+ gpu/mblas/matrix.cu
+ gpu/npz_converter.cu
+ $<TARGET_OBJECTS:libcommon>
+ $<TARGET_OBJECTS:cpumode>
+ $<TARGET_OBJECTS:libyaml-cpp>
+ $<TARGET_OBJECTS:libcnpy>
+)
+
+cuda_add_library(amunmt SHARED
+ python/amunmt.cpp
common/loader_factory.cu
gpu/decoder/ape_penalty.cu
gpu/decoder/encoder_decoder.cu
@@ -44,50 +57,32 @@ cuda_add_executable(
gpu/dl4mt/gru.cu
gpu/npz_converter.cu
$<TARGET_OBJECTS:libcommon>
+ $<TARGET_OBJECTS:libcnpy>
+ $<TARGET_OBJECTS:cpumode>
$<TARGET_OBJECTS:libyaml-cpp>
)
+
else(CUDA_FOUND)
+
add_executable(
amun
common/decoder_main.cpp
- common/config.cpp
- common/exception.cpp
- common/loader_factory.cpp
- common/logging.cpp
- common/vocab.cpp
- common/utils.cpp
- common/god.cpp
- common/history.cpp
- common/loader.cpp
- common/printer.cpp
- common/scorer.cpp
- common/search.cpp
- common/sentence.cpp
- common/processor/bpe.cpp
- cpu/mblas/matrix.cpp
- cpu/mblas/phoenix_functions.cpp
- cpu/dl4mt/decoder.cpp
- cpu/dl4mt/encoder.cpp
- cpu/dl4mt/gru.cpp
- cpu/dl4mt/model.cpp
- cpu/decoder/encoder_decoder.cpp
+ $<TARGET_OBJECTS:libcnpy>
+ $<TARGET_OBJECTS:cpumode>
+ $<TARGET_OBJECTS:libcommon>
+ $<TARGET_OBJECTS:libyaml-cpp>
+)
+add_library(amunmt SHARED
+ python/amunmt.cpp
+ $<TARGET_OBJECTS:libcnpy>
+ $<TARGET_OBJECTS:cpumode>
$<TARGET_OBJECTS:libcommon>
$<TARGET_OBJECTS:libyaml-cpp>
)
endif(CUDA_FOUND)
-#cuda_add_executable(
-# rescorer
-# rescorer/rescorer_main.cu
-# mblas/matrix.cu
-# dl4mt/gru.cu
-# $<TARGET_OBJECTS:librescorer>
-# $<TARGET_OBJECTS:libcommon>
-# $<TARGET_OBJECTS:libyaml-cpp>
-#)
-
-foreach(exec amun)
+foreach(exec amun amunmt)
if(CUDA_FOUND)
target_link_libraries(${exec} ${EXT_LIBS} cuda)
cuda_add_cublas_to_target(${exec})
@@ -98,3 +93,4 @@ foreach(exec amun)
endforeach(exec)
add_subdirectory(bpe)
+add_subdirectory(3rd_party)
diff --git a/src/common/god.cpp b/src/common/god.cpp
index 1b305c0f..db54fb34 100644
--- a/src/common/god.cpp
+++ b/src/common/god.cpp
@@ -24,6 +24,16 @@ God::~God()
}
}
+God& God::Init(const std::string& options) {
+ std::vector<std::string> args = boost::program_options::split_unix(options);
+ int argc = args.size() + 1;
+ char* argv[argc];
+ argv[0] = const_cast<char*>("bogus");
+ for(int i = 1; i < argc; i++)
+ argv[i] = const_cast<char*>(args[i-1].c_str());
+ return Init(argc, argv);
+}
+
God& God::Init(int argc, char** argv) {
return Summon().NonStaticInit(argc, argv);
}
diff --git a/src/python/amunmt.cpp b/src/python/amunmt.cpp
new file mode 100644
index 00000000..9dbdd37c
--- /dev/null
+++ b/src/python/amunmt.cpp
@@ -0,0 +1,73 @@
+#include <cstdlib>
+#include <iostream>
+#include <string>
+#include <boost/timer/timer.hpp>
+#include <boost/thread/tss.hpp>
+#include <boost/python.hpp>
+
+#include "common/god.h"
+#include "common/logging.h"
+#include "common/threadpool.h"
+#include "common/search.h"
+#include "common/printer.h"
+#include "common/sentence.h"
+
+History TranslationTask(const std::string& in, size_t taskCounter) {
+ #ifdef __APPLE__
+ static boost::thread_specific_ptr<Search> s_search;
+ Search *search = s_search.get();
+
+ if(search == NULL) {
+ LOG(info) << "Created Search for thread " << std::this_thread::get_id();
+ search = new Search(taskCounter);
+ s_search.reset(search);
+ }
+ #else
+ thread_local std::unique_ptr<Search> search;
+
+ if(!search) {
+ LOG(info) << "Created Search for thread " << std::this_thread::get_id();
+ search.reset(new Search(taskCounter));
+ }
+ #endif
+
+ return search->Decode(Sentence(taskCounter, in));
+}
+
+void init(const std::string& options) {
+ God::Init(options);
+}
+
+boost::python::list translate(boost::python::list& in) {
+ size_t threadCount = God::Get<size_t>("threads");
+ LOG(info) << "Setting number of threads to " << threadCount;
+
+ ThreadPool pool(threadCount);
+ std::vector<std::future<History>> results;
+
+ boost::python::list output;
+ for(int i = 0; i < boost::python::len(in); ++i) {
+ std::string s = boost::python::extract<std::string>(boost::python::object(in[i]));
+ results.emplace_back(
+ pool.enqueue(
+ [=]{ return TranslationTask(s, i); }
+ )
+ );
+ }
+
+ size_t lineCounter = 0;
+
+ for (auto&& result : results) {
+ std::stringstream ss;
+ Printer(result.get(), lineCounter++, ss);
+ output.append(ss.str());
+ }
+
+ return output;
+}
+
+BOOST_PYTHON_MODULE(libamunmt)
+{
+ boost::python::def("init", init);
+ boost::python::def("translate", translate);
+}
diff --git a/src/python/test.py b/src/python/test.py
new file mode 100644
index 00000000..cba7eee3
--- /dev/null
+++ b/src/python/test.py
@@ -0,0 +1,13 @@
+import libamunmt as nmt
+import sys
+
+nmt.init(sys.argv[1])
+
+sentences = []
+for line in sys.stdin:
+ sentences.append(line.rstrip())
+
+output = nmt.translate(sentences)
+
+for line in output:
+ sys.stdout.write(line)