Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.freedesktop.org/gstreamer/gst-plugins-rs.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/audio
diff options
context:
space:
mode:
authorMathieu Duponchelle <mathieu@centricular.com>2020-05-29 00:55:00 +0300
committerMathieu Duponchelle <mduponchelle1@gmail.com>2020-05-29 23:21:34 +0300
commit815aa80789fc62a23f120528dd46cbe45c3d58d7 (patch)
tree58e8d683cec9b10c79098e7c1dcf479af5192af3 /audio
parent08da51744beb31835ccac36c904c6269add8d1ba (diff)
awstranscriber: implement use-partial-results property
The current implementation only makes use of non-partial results, which requires a very high latency. With this mode, we use items from partial results once they are older than latency - 2 * GRANULARITY_MS. Depending on the latency that the user has set, this may result in reduced accuracy; the default latency has therefore been changed to a fairly conservative sweet spot of 8 seconds. This makes the code somewhat more complex, as AWS does not provide identifiers for items, and their timings can change. Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/348>
Diffstat (limited to 'audio')
-rw-r--r--audio/transcribe/src/aws_transcribe_parse.rs190
1 files changed, 153 insertions, 37 deletions
diff --git a/audio/transcribe/src/aws_transcribe_parse.rs b/audio/transcribe/src/aws_transcribe_parse.rs
index 0656c37bd..bd76108cd 100644
--- a/audio/transcribe/src/aws_transcribe_parse.rs
+++ b/audio/transcribe/src/aws_transcribe_parse.rs
@@ -108,10 +108,11 @@ static RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
.unwrap()
});
-const DEFAULT_LATENCY_MS: u32 = 30000;
+const DEFAULT_LATENCY_MS: u32 = 8000;
+const DEFAULT_USE_PARTIAL_RESULTS: bool = true;
const GRANULARITY_MS: u32 = 100;
-static PROPERTIES: [subclass::Property; 2] = [
+static PROPERTIES: [subclass::Property; 3] = [
subclass::Property("language-code", |name| {
glib::ParamSpec::string(
name,
@@ -123,12 +124,21 @@ static PROPERTIES: [subclass::Property; 2] = [
glib::ParamFlags::READWRITE,
)
}),
+ subclass::Property("use-partial-results", |name| {
+ glib::ParamSpec::boolean(
+ name,
+ "Latency",
+ "Whether partial results from AWS should be used",
+ DEFAULT_USE_PARTIAL_RESULTS,
+ glib::ParamFlags::READWRITE,
+ )
+ }),
subclass::Property("latency", |name| {
glib::ParamSpec::uint(
name,
"Latency",
"Amount of milliseconds to allow AWS transcribe",
- GRANULARITY_MS,
+ 2 * GRANULARITY_MS,
std::u32::MAX,
DEFAULT_LATENCY_MS,
glib::ParamFlags::READWRITE,
@@ -140,6 +150,7 @@ static PROPERTIES: [subclass::Property; 2] = [
struct Settings {
latency_ms: u32,
language_code: Option<String>,
+ use_partial_results: bool,
}
impl Default for Settings {
@@ -147,6 +158,7 @@ impl Default for Settings {
Self {
latency_ms: DEFAULT_LATENCY_MS,
language_code: Some("en-US".to_string()),
+ use_partial_results: DEFAULT_USE_PARTIAL_RESULTS,
}
}
}
@@ -162,6 +174,8 @@ struct State {
buffers: VecDeque<gst::Buffer>,
send_eos: bool,
discont: bool,
+ last_partial_end_time: gst::ClockTime,
+ partial_alternative: Option<TranscriptAlternative>,
}
impl Default for State {
@@ -177,6 +191,8 @@ impl Default for State {
buffers: VecDeque::new(),
send_eos: false,
discont: true,
+ last_partial_end_time: gst::CLOCK_TIME_NONE,
+ partial_alternative: None,
}
}
}
@@ -257,13 +273,19 @@ impl Transcriber {
let (latency, now, mut last_position, send_eos, seqnum) = {
let mut state = self.state.lock().unwrap();
- let send_eos = state.send_eos && state.buffers.is_empty();
-
+ // Multiply GRANULARITY by 2 in order to not send buffers that
+ // are less than GRANULARITY milliseconds away too late
let latency: gst::ClockTime = (self.settings.lock().unwrap().latency_ms as u64
- - GRANULARITY_MS as u64)
+ - (2 * GRANULARITY_MS) as u64)
* gst::MSECOND;
let now = element.get_current_running_time();
+ if let Some(alternative) = state.partial_alternative.take() {
+ self.enqueue(element, &mut state, &alternative, true, latency, now);
+ state.partial_alternative = Some(alternative);
+ }
+ let send_eos = state.send_eos && state.buffers.is_empty();
+
while let Some(buf) = state.buffers.front() {
if now - buf.get_pts() > latency {
/* Safe unwrap, we know we have an item */
@@ -352,6 +374,64 @@ impl Transcriber {
true
}
+ fn enqueue(
+ &self,
+ element: &gst::Element,
+ state: &mut State,
+ alternative: &TranscriptAlternative,
+ partial: bool,
+ latency: gst::ClockTime,
+ now: gst::ClockTime,
+ ) {
+ for item in &alternative.items {
+ let mut start_time: gst::ClockTime =
+ ((item.start_time as f64 * 1_000_000_000.0) as u64).into();
+ let mut end_time: gst::ClockTime =
+ ((item.end_time as f64 * 1_000_000_000.0) as u64).into();
+
+ if start_time <= state.last_partial_end_time {
+ /* Already sent (hopefully) */
+ continue;
+ } else if !partial || start_time + latency < now {
+ /* Should be sent now */
+ gst_debug!(CAT, obj: element, "Item is ready: {}", item.content);
+ let mut buf = gst::Buffer::from_mut_slice(item.content.clone().into_bytes());
+ state.last_partial_end_time = end_time;
+
+ {
+ let buf = buf.get_mut().unwrap();
+
+ if state.discont {
+ buf.set_flags(gst::BufferFlags::DISCONT);
+ state.discont = false;
+ }
+
+ if start_time < state.out_segment.get_position() {
+ gst_debug!(
+ CAT,
+ obj: element,
+ "Adjusting item timing({:?} < {:?})",
+ start_time,
+ state.out_segment.get_position()
+ );
+ start_time = state.out_segment.get_position();
+ if end_time < start_time {
+ end_time = start_time;
+ }
+ }
+
+ buf.set_pts(start_time);
+ buf.set_duration(end_time - start_time);
+ }
+
+ state.buffers.push_back(buf);
+ } else {
+ /* Doesn't need to be sent yet */
+ break;
+ }
+ }
+ }
+
fn loop_fn(
&self,
element: &gst::Element,
@@ -417,50 +497,78 @@ impl Transcriber {
if !transcript.transcript.results.is_empty() {
let mut result = transcript.transcript.results.remove(0);
+ let use_partial_results = self.settings.lock().unwrap().use_partial_results;
if !result.is_partial && !result.alternatives.is_empty() {
- let alternative = result.alternatives.remove(0);
- gst_info!(CAT, obj: element, "Transcript: {}", alternative.transcript);
+ if !use_partial_results {
+ let alternative = result.alternatives.remove(0);
+ gst_info!(
+ CAT,
+ obj: element,
+ "Transcript: {}",
+ alternative.transcript
+ );
- let mut start_time: gst::ClockTime =
- ((result.start_time as f64 * 1_000_000_000.0) as u64).into();
- let end_time: gst::ClockTime =
- ((result.end_time as f64 * 1_000_000_000.0) as u64).into();
+ let mut start_time: gst::ClockTime =
+ ((result.start_time as f64 * 1_000_000_000.0) as u64).into();
+ let end_time: gst::ClockTime =
+ ((result.end_time as f64 * 1_000_000_000.0) as u64).into();
- let mut state = self.state.lock().unwrap();
- let position = state.out_segment.get_position();
+ let mut state = self.state.lock().unwrap();
+ let position = state.out_segment.get_position();
- if end_time < position {
- gst_warning!(CAT, obj: element,
- "Received transcript is too late by {:?}, dropping, consider increasing the latency",
- position - start_time);
- } else {
- if start_time < position {
+ if end_time < position {
gst_warning!(CAT, obj: element,
- "Received transcript is too late by {:?}, clipping, consider increasing the latency",
+ "Received transcript is too late by {:?}, dropping, consider increasing the latency",
position - start_time);
- start_time = position;
- }
+ } else {
+ if start_time < position {
+ gst_warning!(CAT, obj: element,
+ "Received transcript is too late by {:?}, clipping, consider increasing the latency",
+ position - start_time);
+ start_time = position;
+ }
- let mut buf = gst::Buffer::from_mut_slice(
- alternative.transcript.into_bytes(),
- );
+ let mut buf = gst::Buffer::from_mut_slice(
+ alternative.transcript.into_bytes(),
+ );
- {
- let buf = buf.get_mut().unwrap();
+ {
+ let buf = buf.get_mut().unwrap();
- if state.discont {
- buf.set_flags(gst::BufferFlags::DISCONT);
- state.discont = false;
- }
+ if state.discont {
+ buf.set_flags(gst::BufferFlags::DISCONT);
+ state.discont = false;
+ }
- buf.set_pts(start_time);
- buf.set_duration(end_time - start_time);
- }
+ buf.set_pts(start_time);
+ buf.set_duration(end_time - start_time);
+ }
- gst_debug!(CAT, obj: element, "Adding pending buffer: {:?}", buf);
+ gst_debug!(
+ CAT,
+ obj: element,
+ "Adding pending buffer: {:?}",
+ buf
+ );
- state.buffers.push_back(buf);
+ state.buffers.push_back(buf);
+ }
+ } else {
+ let alternative = result.alternatives.remove(0);
+ let mut state = self.state.lock().unwrap();
+ self.enqueue(
+ element,
+ &mut state,
+ &alternative,
+ false,
+ 0.into(),
+ 0.into(),
+ );
+ state.partial_alternative = None;
}
+ } else if !result.alternatives.is_empty() && use_partial_results {
+ let mut state = self.state.lock().unwrap();
+ state.partial_alternative = Some(result.alternatives.remove(0));
}
}
Ok(())
@@ -1001,6 +1109,10 @@ impl ObjectImpl for Transcriber {
let mut settings = self.settings.lock().unwrap();
settings.latency_ms = value.get_some().expect("type checked upstream");
}
+ subclass::Property("use-partial-results", ..) => {
+ let mut settings = self.settings.lock().unwrap();
+ settings.use_partial_results = value.get_some().expect("type checked upstream");
+ }
_ => unimplemented!(),
}
}
@@ -1017,6 +1129,10 @@ impl ObjectImpl for Transcriber {
let settings = self.settings.lock().unwrap();
Ok(settings.latency_ms.to_value())
}
+ subclass::Property("use-partial-results", ..) => {
+ let settings = self.settings.lock().unwrap();
+ Ok(settings.use_partial_results.to_value())
+ }
_ => unimplemented!(),
}
}