Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/moses
diff options
context:
space:
mode:
authorLane Schwartz <dowobeha@gmail.com>2016-11-14 22:26:34 +0300
committerLane Schwartz <dowobeha@gmail.com>2016-11-14 22:26:34 +0300
commit05006bf1e2c68395a63ebd21a4f7ee56e38f260f (patch)
treedc589aab6af2ea30b8e110921730fe0b1fe8edef /moses
parentea9d3b7f3ea5c0a9210c906b96824695b8a77ede (diff)
Allow XML-RPC requests to update weights
Diffstat (limited to 'moses')
-rw-r--r--moses/LM/Reloading.h26
-rw-r--r--moses/server/TranslationRequest.cpp57
2 files changed, 83 insertions, 0 deletions
diff --git a/moses/LM/Reloading.h b/moses/LM/Reloading.h
index 7075cb429..88f8e8869 100644
--- a/moses/LM/Reloading.h
+++ b/moses/LM/Reloading.h
@@ -48,6 +48,32 @@ public:
virtual void InitializeForInput(ttasksptr const& ttask) {
VERBOSE(1, "ReloadingLM InitializeForInput" << std::endl);
+ // The context scope object for this translation task
+ // contains a map of translation task-specific data
+ boost::shared_ptr<Moses::ContextScope> contextScope = ttask->GetScope();
+
+ // The key to the map is this object
+ void const* key = static_cast<void const*>(this);
+
+ // The value stored in the map is a string representing a phrase table
+ boost::shared_ptr<string> value = contextScope->get<string>(key);
+
+ // Create a stream to read the phrase table data
+ stringstream strme(*(value.get()));
+
+ ofstream tmp;
+ tmp.open(m_file.c_str());
+
+ // Read the phrase table data, one line at a time
+ string line;
+ while (getline(strme, line)) {
+
+ tmp << line << "\n";
+
+ }
+
+ tmp.close();
+
LanguageModelKen<Model>::LoadModel(m_file, m_lazy ? util::LAZY : util::POPULATE_OR_READ);
};
diff --git a/moses/server/TranslationRequest.cpp b/moses/server/TranslationRequest.cpp
index 4e97cff6a..e1821a265 100644
--- a/moses/server/TranslationRequest.cpp
+++ b/moses/server/TranslationRequest.cpp
@@ -26,6 +26,7 @@ using Moses::FindPhraseDictionary;
using Moses::Sentence;
using Moses::TokenizeMultiCharSeparator;
using Moses::FeatureFunction;
+using Moses::Scan;
boost::shared_ptr<TranslationRequest>
TranslationRequest::
@@ -352,6 +353,62 @@ parse_request(std::map<std::string, xmlrpc_c::value> const& params)
}
}
+ si = params.find("weights");
+ if (si != params.end())
+ {
+
+ boost::unordered_map<string, FeatureFunction*> map;
+ {
+ const vector<FeatureFunction*> &ffs = FeatureFunction::GetFeatureFunctions();
+ BOOST_FOREACH(FeatureFunction* const& ff, ffs) {
+ map[ff->GetScoreProducerDescription()] = ff;
+ }
+ }
+
+ string allValues = xmlrpc_c::value_string(si->second);
+
+ BOOST_FOREACH(string values, TokenizeMultiCharSeparator(allValues, "\t")) {
+
+ vector<string> record = TokenizeMultiCharSeparator(values, "=");
+
+ if (record.size() == 2) {
+ string featureName = record[0];
+ string featureWeights = record[1];
+
+ boost::unordered_map<string, FeatureFunction*>::iterator ffi = map.find(featureName);
+
+ if (ffi != map.end()) {
+ FeatureFunction* ff = ffi->second;
+
+ size_t prevNumWeights = ff->GetNumScoreComponents();
+
+ vector<float> ffWeights;
+ BOOST_FOREACH(string weight, TokenizeMultiCharSeparator(featureWeights, " ")) {
+ ffWeights.push_back(Scan<float>(weight));
+ }
+
+ if (ffWeights.size() == ff->GetNumScoreComponents()) {
+
+ // XXX: This is NOT thread-safe
+ Moses::StaticData::InstanceNonConst().SetWeights(ff, ffWeights);
+ VERBOSE(1, "WARNING: THIS IS NOT THREAD-SAFE!\tUpdating weights for " << featureName << " to " << featureWeights << "\n");
+
+ } else {
+ TRACE_ERR("ERROR: Unable to update weights for " << featureName << " because " << ff->GetNumScoreComponents() << " weights are required but only " << ffWeights.size() << " were provided\n");
+ }
+
+ } else {
+ TRACE_ERR("ERROR: No FeatureFunction with name " << featureName << ", no weight update\n");
+ }
+
+ } else {
+ TRACE_ERR("WARNING: XML-RPC weights update was improperly formatted:\t" << values << "\n");
+ }
+
+ }
+
+ }
+
// // biased sampling for suffix-array-based sampling phrase table?
// if ((si = params.find("bias")) != params.end())