Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/sdroege/gst-plugin-rs.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/net/aws
diff options
context:
space:
mode:
authorFrançois Laignel <francois@centricular.com>2023-03-16 20:20:08 +0300
committerFrançois Laignel <francois@centricular.com>2023-03-16 22:29:31 +0300
commit2b32d00589c23d3914f1b825074d96532df5d930 (patch)
treef452f45077b1266aa6a108e1fd2567c7140ed7a2 /net/aws
parent5a5ca76d9d750c58dd801a58d751849044dd2c01 (diff)
net/aws/transcriber: use two queues for sending transcript items
* A queue dedicated to transcript items not intended for translation. * A queue dedicated to transcript items intended for translation. The items are enqueued after a separator is detected or translate-lookahead was reached. Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1137>
Diffstat (limited to 'net/aws')
-rw-r--r--net/aws/src/transcriber/imp.rs520
-rw-r--r--net/aws/src/transcriber/transcribe.rs104
-rw-r--r--net/aws/src/transcriber/translate.rs74
3 files changed, 374 insertions, 324 deletions
diff --git a/net/aws/src/transcriber/imp.rs b/net/aws/src/transcriber/imp.rs
index 0c99c760..cc3f8ad2 100644
--- a/net/aws/src/transcriber/imp.rs
+++ b/net/aws/src/transcriber/imp.rs
@@ -29,12 +29,12 @@ use futures::prelude::*;
use tokio::{runtime, sync::broadcast, task};
use std::collections::{BTreeSet, VecDeque};
-use std::sync::Mutex;
+use std::sync::{Arc, Mutex};
use once_cell::sync::Lazy;
-use super::transcribe::{TranscriberLoop, TranscriptEvent, TranscriptItem, TranscriptionSettings};
-use super::translate::{TranslateLoop, TranslateQueue, TranslatedItem};
+use super::transcribe::{TranscriberSettings, TranscriberStream, TranscriptEvent, TranscriptItem};
+use super::translate::{TranslateLoop, TranslatedItem};
use super::{
AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod,
TranslationTokenizationMethod, CAT,
@@ -148,6 +148,7 @@ struct State {
srcpads: BTreeSet<super::TranslateSrcPad>,
pad_serial: u32,
seqnum: gst::Seqnum,
+ start_time: Option<gst::ClockTime>,
}
impl Default for State {
@@ -158,6 +159,7 @@ impl Default for State {
srcpads: Default::default(),
pad_serial: 0,
seqnum: gst::Seqnum::next(),
+ start_time: None,
}
}
}
@@ -168,7 +170,9 @@ pub struct Transcriber {
settings: Mutex<Settings>,
state: Mutex<State>,
pub(super) aws_config: Mutex<Option<aws_config::SdkConfig>>,
- // sender to broadcast transcript items to the translate src pads.
+ // sender to broadcast transcript items to the src pads for translation.
+ transcript_event_for_translate_tx: broadcast::Sender<TranscriptEvent>,
+ // sender to broadcast transcript items to the src pads, not intended for translation.
transcript_event_tx: broadcast::Sender<TranscriptEvent>,
}
@@ -276,7 +280,10 @@ impl Transcriber {
) -> Result<gst::FlowSuccess, gst::FlowError> {
gst::log!(CAT, obj: pad, "Handling {buffer:?}");
- self.ensure_connection();
+ self.ensure_connection().map_err(|err| {
+ gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]);
+ gst::FlowError::Error
+ })?;
let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else {
gst::log!(CAT, obj: pad, "Flushing");
@@ -292,12 +299,82 @@ impl Transcriber {
Ok(gst::FlowSuccess::Ok)
}
+}
+
+#[derive(Default)]
+struct TranslateQueue {
+ items: VecDeque<TranscriptItem>,
+}
+
+impl TranslateQueue {
+ fn is_empty(&self) -> bool {
+ self.items.is_empty()
+ }
+
+ /// Pushes the provided item.
+ ///
+ /// Returns `Some(..)` if items are ready for translation.
+ fn push(&mut self, transcript_item: &TranscriptItem) -> Option<Vec<TranscriptItem>> {
+ // Keep track of the item individually so we can schedule translation precisely.
+ self.items.push_back(transcript_item.clone());
+
+ if transcript_item.is_punctuation {
+ // This makes it a good chunk for translation.
+ // Concatenate as a single item for translation
+
+ return Some(self.items.drain(..).collect());
+ }
+
+ // Regular case: no separator detected, don't push transcript items
+ // to translation now. They will be pushed either if a punctuation
+ // is found or if a `dequeue()` is requested.
+
+ None
+ }
+
+ /// Dequeues items once the first item's `pts + latency` has reached `threshold`, taking items up to `lookahead` past that first item.
+ ///
+ /// Returns `Some(..)` if some items match the criteria.
+ fn dequeue(
+ &mut self,
+ latency: gst::ClockTime,
+ threshold: gst::ClockTime,
+ lookahead: gst::ClockTime,
+ ) -> Option<Vec<TranscriptItem>> {
+ let first_pts = self.items.front()?.pts;
+ if first_pts + latency > threshold {
+ // First item is too early to be sent to translation now
+ // we can wait for more items to accumulate.
+ return None;
+ }
+
+ // Can't wait any longer to send the first item to translation
+ // Try to get up to lookahead worth of items to improve translation accuracy
+ let limit = first_pts + lookahead;
+
+ let mut items_acc = vec![self.items.pop_front().unwrap()];
+ while let Some(item) = self.items.front() {
+ if item.pts > limit {
+ break;
+ }
+
+ items_acc.push(self.items.pop_front().unwrap());
+ }
- fn ensure_connection(&self) {
+ Some(items_acc)
+ }
+
+ fn drain(&mut self) -> impl Iterator<Item = TranscriptItem> + '_ {
+ self.items.drain(..)
+ }
+}
+
+impl Transcriber {
+ fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> {
let mut state = self.state.lock().unwrap();
if state.buffer_tx.is_some() {
- return;
+ return Ok(());
}
let settings = self.settings.lock().unwrap();
@@ -306,21 +383,116 @@ impl Transcriber {
let s = in_caps.structure(0).unwrap();
let sample_rate = s.get::<i32>("rate").unwrap();
- let transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
+ let transcription_settings = TranscriberSettings::from(&settings, sample_rate);
let (buffer_tx, buffer_rx) = mpsc::channel(1);
- let transcriber_loop = TranscriberLoop::new(
+ let _enter = RUNTIME.enter();
+ let mut transcriber_stream = futures::executor::block_on(TranscriberStream::try_new(
self,
transcription_settings,
settings.lateness,
buffer_rx,
- self.transcript_event_tx.clone(),
- );
- let transcriber_loop_handle = RUNTIME.spawn(transcriber_loop.run());
+ ))?;
+
+ // Latency budget for an item to be pushed to stream on time
+ // Margin:
+ // - 2 * GRANULARITY: to make sure we don't push items up to GRANULARITY late.
+ // - 1 * GRANULARITY: extra margin to account for additional overheads.
+ let latency = settings.transcribe_latency.saturating_sub(3 * GRANULARITY);
+ let translate_lookahead = settings.translate_lookahead;
+ let mut translate_queue = TranslateQueue::default();
+ let imp = self.ref_counted();
+ let transcriber_loop_handle = RUNTIME.spawn(async move {
+ loop {
+ // This is to make sure we send items on a timely basis or at least Gap events.
+ let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
+ futures::pin_mut!(timeout);
+
+ let transcriber_next = transcriber_stream.next().fuse();
+ futures::pin_mut!(transcriber_next);
+
+ // `transcriber_next` takes precedence over `timeout`
+ // because we don't want to lose any incoming items.
+ let res = futures::select_biased! {
+ event = transcriber_next => Some(event?),
+ _ = timeout => None,
+ };
+
+ use TranscriptEvent::*;
+ match res {
+ None => (),
+ Some(Items(items)) => {
+ if imp.transcript_event_tx.receiver_count() > 0 {
+ let _ = imp.transcript_event_tx.send(Items(items.clone()));
+ }
+
+ if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
+ for item in items.iter() {
+ if let Some(items_to_translate) = translate_queue.push(item) {
+ let _ = imp
+ .transcript_event_for_translate_tx
+ .send(Items(items_to_translate.into()));
+ }
+ }
+ }
+ }
+ Some(Eos) => {
+ gst::debug!(CAT, imp: imp, "Transcriber loop sending EOS");
+
+ if imp.transcript_event_tx.receiver_count() > 0 {
+ let _ = imp.transcript_event_tx.send(Eos);
+ }
+
+ if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
+ let items_to_translate: Vec<TranscriptItem> =
+ translate_queue.drain().collect();
+ let _ = imp
+ .transcript_event_for_translate_tx
+ .send(Items(items_to_translate.into()));
+
+ let _ = imp.transcript_event_for_translate_tx.send(Eos);
+ }
+
+ break;
+ }
+ }
+
+ if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
+ // Check if we need to push items for translation
+
+ let Some((start_time, now)) = imp.get_start_time_and_now() else {
+ continue;
+ };
+
+ if !translate_queue.is_empty() {
+ let threshold = now - start_time;
+
+ if let Some(items_to_translate) =
+ translate_queue.dequeue(latency, threshold, translate_lookahead)
+ {
+ gst::debug!(
+ CAT,
+ imp: imp,
+ "Forcing to translation (threshold {threshold}): {items_to_translate:?}"
+ );
+ let _ = imp
+ .transcript_event_for_translate_tx
+ .send(Items(items_to_translate.into()));
+ }
+ }
+ }
+ }
+
+ gst::debug!(CAT, imp: imp, "Exiting transcriber loop");
+
+ Ok(())
+ });
state.transcriber_loop_handle = Some(transcriber_loop_handle);
state.buffer_tx = Some(buffer_tx);
+
+ Ok(())
}
fn prepare(&self) -> Result<(), gst::ErrorMessage> {
@@ -382,6 +554,18 @@ impl Transcriber {
}
gst::info!(CAT, imp: self, "Unprepared");
}
+
+ fn get_start_time_and_now(&self) -> Option<(gst::ClockTime, gst::ClockTime)> {
+ let now = self.obj().current_running_time()?;
+
+ let mut state = self.state.lock().unwrap();
+
+ if state.start_time.is_none() {
+ state.start_time = Some(now);
+ }
+
+ Some((state.start_time.unwrap(), now))
+ }
}
#[glib::object_subclass]
@@ -438,6 +622,7 @@ impl ObjectSubclass for Transcriber {
// Setting the channel capacity so that a TranslateSrcPad that would lag
// behind for some reason gets a chance to catch up without losing items.
// Receiver will be created by subscribing to sender later.
+ let (transcript_event_for_translate_tx, _) = broadcast::channel(128);
let (transcript_event_tx, _) = broadcast::channel(128);
Self {
@@ -446,6 +631,7 @@ impl ObjectSubclass for Transcriber {
settings: Default::default(),
state: Default::default(),
aws_config: Default::default(),
+ transcript_event_for_translate_tx,
transcript_event_tx,
}
}
@@ -876,51 +1062,93 @@ struct TranslationPadTask {
elem: super::Transcriber,
transcript_event_rx: broadcast::Receiver<TranscriptEvent>,
needs_translate: bool,
- translate_queue: TranslateQueue,
translate_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
- to_translate_tx: Option<mpsc::Sender<Vec<TranscriptItem>>>,
+ to_translate_tx: Option<mpsc::Sender<Arc<Vec<TranscriptItem>>>>,
from_translate_rx: Option<mpsc::Receiver<Vec<TranslatedItem>>>,
- translate_latency: gst::ClockTime,
- translate_lookahead: gst::ClockTime,
send_events: bool,
output_items: VecDeque<OutputItem>,
our_latency: gst::ClockTime,
seqnum: gst::Seqnum,
send_eos: bool,
pending_translations: usize,
- start_time: Option<gst::ClockTime>,
}
impl TranslationPadTask {
- fn try_new(
+ async fn try_new(
pad: &TranslateSrcPad,
elem: super::Transcriber,
- transcript_event_rx: broadcast::Receiver<TranscriptEvent>,
) -> Result<TranslationPadTask, gst::ErrorMessage> {
- let mut this = TranslationPadTask {
+ let mut translation_loop = None;
+ let mut translate_loop_handle = None;
+ let mut to_translate_tx = None;
+ let mut from_translate_rx = None;
+
+ let (our_latency, transcript_event_rx, needs_translate);
+
+ {
+ let elem_imp = elem.imp();
+ let elem_settings = elem_imp.settings.lock().unwrap();
+
+ let pad_settings = pad.settings.lock().unwrap();
+
+ our_latency = TranslateSrcPad::our_latency(&elem_settings, &pad_settings);
+ if our_latency + elem_settings.lateness <= 2 * GRANULARITY {
+ let err = format!(
+ "total latency + lateness must be greater than {}",
+ 2 * GRANULARITY
+ );
+ gst::error!(CAT, imp: pad, "{err}");
+ return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"]));
+ }
+
+ needs_translate = TranslateSrcPad::needs_translation(
+ &elem_settings.language_code,
+ pad_settings.language_code.as_deref(),
+ );
+
+ if needs_translate {
+ let (to_loop_tx, to_loop_rx) = mpsc::channel(64);
+ let (from_loop_tx, from_loop_rx) = mpsc::channel(64);
+
+ translation_loop = Some(TranslateLoop::new(
+ elem_imp,
+ pad,
+ &elem_settings.language_code,
+ pad_settings.language_code.as_deref().unwrap(),
+ pad_settings.tokenization_method,
+ to_loop_rx,
+ from_loop_tx,
+ ));
+
+ to_translate_tx = Some(to_loop_tx);
+ from_translate_rx = Some(from_loop_rx);
+
+ transcript_event_rx = elem_imp.transcript_event_for_translate_tx.subscribe();
+ } else {
+ transcript_event_rx = elem_imp.transcript_event_tx.subscribe();
+ }
+ }
+
+ if let Some(translation_loop) = translation_loop {
+ translation_loop.check_language().await?;
+ translate_loop_handle = Some(RUNTIME.spawn(translation_loop.run()));
+ }
+
+ Ok(TranslationPadTask {
pad: pad.ref_counted(),
elem,
transcript_event_rx,
- needs_translate: false,
- translate_queue: TranslateQueue::default(),
- translate_loop_handle: None,
- to_translate_tx: None,
- from_translate_rx: None,
- translate_latency: DEFAULT_TRANSLATE_LATENCY,
- translate_lookahead: DEFAULT_TRANSLATE_LOOKAHEAD,
+ needs_translate,
+ translate_loop_handle,
+ to_translate_tx,
+ from_translate_rx,
send_events: true,
output_items: VecDeque::new(),
- our_latency: DEFAULT_TRANSCRIBE_LATENCY,
+ our_latency,
seqnum: gst::Seqnum::next(),
send_eos: false,
pending_translations: 0,
- start_time: None,
- };
-
- let _enter_guard = RUNTIME.enter();
- futures::executor::block_on(this.init_translate())?;
-
- Ok(this)
+ })
}
}
@@ -958,11 +1186,9 @@ impl TranslationPadTask {
let transcript_event_rx = self.transcript_event_rx.recv().fuse();
futures::pin_mut!(transcript_event_rx);
- // `timeout` takes precedence over `transcript_events` reception
- // because we may need to `dequeue` `items` or push a `Gap` event
- // before current latency budget is exhausted.
+ // `transcript_event_rx` takes precedence over `timeout`
+ // because we don't want to lose any incoming items.
futures::select_biased! {
- _ = timeout => (),
items_res = transcript_event_rx => {
use TranscriptEvent::*;
use broadcast::error::RecvError;
@@ -983,6 +1209,7 @@ impl TranslationPadTask {
}
}
}
+ _ = timeout => (),
}
Ok(())
@@ -999,121 +1226,100 @@ impl TranslationPadTask {
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
}
- let transcript_items = {
+ let items_to_translate = {
// This is to make sure we send items on a timely basis or at least Gap events.
let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
futures::pin_mut!(timeout);
- let from_translate_rx = self
- .from_translate_rx
- .as_mut()
- .expect("from_translation chan must be available in translation mode");
-
let transcript_event_rx = self.transcript_event_rx.recv().fuse();
futures::pin_mut!(transcript_event_rx);
- // `timeout` takes precedence over `transcript_events` reception
- // because we may need to `dequeue` `items` or push a `Gap` event
- // before current latency budget is exhausted.
+ // `transcript_event_rx` takes precedence over `timeout`
+ // because we don't want to lose any incoming items.
futures::select_biased! {
- _ = timeout => return Ok(()),
- translated_items = from_translate_rx.next() => {
- let Some(translated_items) = translated_items else {
- const ERR: &str = "translation chan terminated";
- gst::debug!(CAT, imp: self.pad, "{ERR}");
- return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
- };
-
- self.output_items.extend(translated_items.into_iter().map(Into::into));
- self.pending_translations = self.pending_translations.saturating_sub(1);
-
- return Ok(());
- }
items_res = transcript_event_rx => {
use TranscriptEvent::*;
use broadcast::error::RecvError;
match items_res {
- Ok(Items(transcript_items)) => transcript_items,
+ Ok(Items(items_to_translate)) => Some(items_to_translate),
Ok(Eos) => {
gst::debug!(CAT, imp: self.pad, "Got eos");
self.send_eos = true;
- return Ok(());
+ None
}
Err(RecvError::Lagged(nb_msg)) => {
gst::warning!(CAT, imp: self.pad, "Missed {nb_msg} transcript sets");
- return Ok(());
+ None
}
Err(RecvError::Closed) => {
gst::debug!(CAT, imp: self.pad, "Transcript chan terminated: setting eos");
self.send_eos = true;
- return Ok(());
+ None
}
}
}
+ _ = timeout => None,
}
};
- for items in transcript_items.iter() {
- if let Some(items_to_translate) = self.translate_queue.push(items) {
- self.send_for_translation(items_to_translate).await?;
+ if let Some(items_to_translate) = items_to_translate {
+ if !items_to_translate.is_empty() {
+ let res = self
+ .to_translate_tx
+ .as_mut()
+ .expect("to_translation chan must be available in translation mode")
+ .send(items_to_translate)
+ .await;
+
+ if res.is_err() {
+ const ERR: &str = "to_translation chan terminated";
+ gst::debug!(CAT, imp: self.pad, "{ERR}");
+ return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
+ }
+
+ self.pending_translations += 1;
}
}
- Ok(())
- }
+ // Check pending translated items
+ let from_translate_rx = self
+ .from_translate_rx
+ .as_mut()
+ .expect("from_translation chan must be available in translation mode");
- async fn dequeue_for_translation(
- &mut self,
- start_time: gst::ClockTime,
- now: gst::ClockTime,
- ) -> Result<(), gst::ErrorMessage> {
- if !self.translate_queue.is_empty() {
- // Latency budget for an item to be pushed to stream on time
- // Margin:
- // - 2 * GRANULARITY: to make sure we don't push items up to GRANULARITY late.
- // - 1 * GRANULARITY: extra margin to account for additional overheads.
- let latency = self.our_latency.saturating_sub(3 * GRANULARITY);
-
- // Estimated time of arrival for an item sent to translation now.
- // (in transcript item ts base)
- let translation_eta = now + self.translate_latency - start_time;
-
- if let Some(items_to_translate) =
- self.translate_queue
- .dequeue(latency, translation_eta, self.translate_lookahead)
- {
- gst::debug!(CAT, imp: self.pad, "Forcing to translation: {items_to_translate:?}");
- self.send_for_translation(items_to_translate).await?;
- }
+ while let Ok(translated_items) = from_translate_rx.try_next() {
+ let Some(translated_items) = translated_items else {
+ const ERR: &str = "translation chan terminated";
+ gst::debug!(CAT, imp: self.pad, "{ERR}");
+ return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
+ };
+
+ self.output_items
+ .extend(translated_items.into_iter().map(Into::into));
+ self.pending_translations = self.pending_translations.saturating_sub(1);
}
Ok(())
}
async fn dequeue(&mut self) -> bool {
- let (now, start_time, mut last_position, mut discont_pending);
- {
- let mut pad_state = self.pad.state.lock().unwrap();
-
- let Some(cur_rt) = self.elem.current_running_time() else {
- // Wait for the clock to be available
- return true;
- };
- now = cur_rt;
+ let Some((start_time, now)) = self.elem.imp().get_start_time_and_now() else {
+ // Wait for the clock to be available
+ return true;
+ };
- if self.start_time.is_none() {
- self.start_time = Some(now);
- pad_state.out_segment.set_position(now);
- }
+ let (mut last_position, mut discont_pending) = {
+ let mut state = self.pad.state.lock().unwrap();
- start_time = self.start_time.unwrap();
- last_position = pad_state.out_segment.position().unwrap();
- discont_pending = pad_state.discont_pending;
- }
+ let last_position = if let Some(pos) = state.out_segment.position() {
+ pos
+ } else {
+ state.out_segment.set_position(start_time);
+ start_time
+ };
- if self.needs_translate && self.dequeue_for_translation(start_time, now).await.is_err() {
- return false;
- }
+ (last_position, state.discont_pending)
+ };
/* First, check our pending buffers */
while let Some(item) = self.output_items.front() {
@@ -1206,11 +1412,7 @@ impl TranslationPadTask {
}
}
- if self.send_eos
- && self.pending_translations == 0
- && self.output_items.is_empty()
- && self.translate_queue.is_empty()
- {
+ if self.send_eos && self.pending_translations == 0 && self.output_items.is_empty() {
/* We're EOS, we can pause and exit early */
let _ = self.pad.obj().pause_task();
@@ -1261,28 +1463,6 @@ impl TranslationPadTask {
true
}
- async fn send_for_translation(
- &mut self,
- transcript_items: Vec<TranscriptItem>,
- ) -> Result<(), gst::ErrorMessage> {
- let res = self
- .to_translate_tx
- .as_mut()
- .expect("to_translation chan must be available in translation mode")
- .send(transcript_items)
- .await;
-
- if res.is_err() {
- const ERR: &str = "to_translation chan terminated";
- gst::debug!(CAT, imp: self.pad, "{ERR}");
- return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
- }
-
- self.pending_translations += 1;
-
- Ok(())
- }
-
fn ensure_init_events(&mut self) -> Result<(), gst::ErrorMessage> {
if !self.send_events {
return Ok(());
@@ -1332,62 +1512,6 @@ impl TranslationPadTask {
}
}
-impl TranslationPadTask {
- async fn init_translate(&mut self) -> Result<(), gst::ErrorMessage> {
- let mut translation_loop = None;
-
- {
- let elem_imp = self.elem.imp();
- let elem_settings = elem_imp.settings.lock().unwrap();
-
- let pad_settings = self.pad.settings.lock().unwrap();
-
- self.our_latency = TranslateSrcPad::our_latency(&elem_settings, &pad_settings);
- if self.our_latency + elem_settings.lateness <= 2 * GRANULARITY {
- let err = format!(
- "total latency + lateness must be greater than {}",
- 2 * GRANULARITY
- );
- gst::error!(CAT, imp: self.pad, "{err}");
- return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"]));
- }
-
- self.translate_latency = elem_settings.translate_latency;
- self.translate_lookahead = elem_settings.translate_lookahead;
-
- self.needs_translate = TranslateSrcPad::needs_translation(
- &elem_settings.language_code,
- pad_settings.language_code.as_deref(),
- );
-
- if self.needs_translate {
- let (to_translate_tx, to_translate_rx) = mpsc::channel(64);
- let (from_translate_tx, from_translate_rx) = mpsc::channel(64);
-
- translation_loop = Some(TranslateLoop::new(
- elem_imp,
- &self.pad,
- &elem_settings.language_code,
- pad_settings.language_code.as_deref().unwrap(),
- pad_settings.tokenization_method,
- to_translate_rx,
- from_translate_tx,
- ));
-
- self.to_translate_tx = Some(to_translate_tx);
- self.from_translate_rx = Some(from_translate_rx);
- }
- }
-
- if let Some(translation_loop) = translation_loop {
- translation_loop.check_language().await?;
- self.translate_loop_handle = Some(RUNTIME.spawn(translation_loop.run()));
- }
-
- Ok(())
- }
-}
-
#[derive(Debug)]
struct TranslationPadState {
discont_pending: bool,
@@ -1422,8 +1546,8 @@ impl TranslateSrcPad {
gst::debug!(CAT, imp: self, "Starting task");
let elem = self.parent();
- let transcript_event_rx = elem.imp().transcript_event_tx.subscribe();
- let mut pad_task = TranslationPadTask::try_new(self, elem, transcript_event_rx)
+ let _enter = RUNTIME.enter();
+ let mut pad_task = futures::executor::block_on(TranslationPadTask::try_new(self, elem))
.map_err(|err| gst::loggable_error!(CAT, format!("Failed to start pad task {err}")))?;
let imp = self.ref_counted();
diff --git a/net/aws/src/transcriber/transcribe.rs b/net/aws/src/transcriber/transcribe.rs
index 97301380..d094d87a 100644
--- a/net/aws/src/transcriber/transcribe.rs
+++ b/net/aws/src/transcriber/transcribe.rs
@@ -15,7 +15,6 @@ use aws_sdk_transcribestreaming::model;
use futures::channel::mpsc;
use futures::prelude::*;
-use tokio::sync::broadcast;
use std::sync::Arc;
@@ -23,7 +22,7 @@ use super::imp::{Settings, Transcriber};
use super::CAT;
#[derive(Debug)]
-pub struct TranscriptionSettings {
+pub struct TranscriberSettings {
lang_code: model::LanguageCode,
sample_rate: i32,
vocabulary: Option<String>,
@@ -33,9 +32,9 @@ pub struct TranscriptionSettings {
results_stability: model::PartialResultsStability,
}
-impl TranscriptionSettings {
+impl TranscriberSettings {
pub(super) fn from(settings: &Settings, sample_rate: i32) -> Self {
- TranscriptionSettings {
+ TranscriberSettings {
lang_code: settings.language_code.as_str().into(),
sample_rate,
vocabulary: settings.vocabulary.clone(),
@@ -83,43 +82,30 @@ impl From<Vec<TranscriptItem>> for TranscriptEvent {
}
}
-pub struct TranscriberLoop {
+pub struct TranscriberStream {
imp: glib::subclass::ObjectImplRef<Transcriber>,
- client: aws_transcribe::Client,
- settings: Option<TranscriptionSettings>,
+ output: aws_transcribe::output::StartStreamTranscriptionOutput,
lateness: gst::ClockTime,
- buffer_rx: Option<mpsc::Receiver<gst::Buffer>>,
- transcript_items_tx: broadcast::Sender<TranscriptEvent>,
partial_index: usize,
}
-impl TranscriberLoop {
- pub fn new(
+impl TranscriberStream {
+ pub async fn try_new(
imp: &Transcriber,
- settings: TranscriptionSettings,
+ settings: TranscriberSettings,
lateness: gst::ClockTime,
buffer_rx: mpsc::Receiver<gst::Buffer>,
- transcript_items_tx: broadcast::Sender<TranscriptEvent>,
- ) -> Self {
- let aws_config = imp.aws_config.lock().unwrap();
- let aws_config = aws_config
- .as_ref()
- .expect("aws_config must be initialized at this stage");
-
- TranscriberLoop {
- imp: imp.ref_counted(),
- client: aws_transcribe::Client::new(aws_config),
- settings: Some(settings),
- lateness,
- buffer_rx: Some(buffer_rx),
- transcript_items_tx,
- partial_index: 0,
- }
- }
+ ) -> Result<Self, gst::ErrorMessage> {
+ let client = {
+ let aws_config = imp.aws_config.lock().unwrap();
+ let aws_config = aws_config
+ .as_ref()
+ .expect("aws_config must be initialized at this stage");
+ aws_transcribe::Client::new(aws_config)
+ };
- pub async fn run(mut self) -> Result<(), gst::ErrorMessage> {
// Stream the incoming buffers chunked
- let chunk_stream = self.buffer_rx.take().unwrap().flat_map(move |buffer: gst::Buffer| {
+ let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| {
async_stream::stream! {
let data = buffer.map_readable().unwrap();
use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
@@ -129,9 +115,7 @@ impl TranscriberLoop {
}
});
- let settings = self.settings.take().unwrap();
- let mut transcribe_builder = self
- .client
+ let mut transcribe_builder = client
.start_stream_transcription()
.language_code(settings.lang_code)
.media_sample_rate_hertz(settings.sample_rate)
@@ -147,26 +131,42 @@ impl TranscriberLoop {
.vocabulary_filter_method(settings.vocabulary_filter_method);
}
- let mut output = transcribe_builder
+ let output = transcribe_builder
.audio_stream(chunk_stream.into())
.send()
.await
.map_err(|err| {
let err = format!("Transcribe ws init error: {err}");
- gst::error!(CAT, imp: self.imp, "{err}");
+ gst::error!(CAT, imp: imp, "{err}");
gst::error_msg!(gst::LibraryError::Init, ["{err}"])
})?;
- while let Some(event) = output
- .transcript_result_stream
- .recv()
- .await
- .map_err(|err| {
- let err = format!("Transcribe ws stream error: {err}");
- gst::error!(CAT, imp: self.imp, "{err}");
- gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
- })?
- {
+ Ok(TranscriberStream {
+ imp: imp.ref_counted(),
+ output,
+ lateness,
+ partial_index: 0,
+ })
+ }
+
+ pub async fn next(&mut self) -> Result<TranscriptEvent, gst::ErrorMessage> {
+ loop {
+ let event = self
+ .output
+ .transcript_result_stream
+ .recv()
+ .await
+ .map_err(|err| {
+ let err = format!("Transcribe ws stream error: {err}");
+ gst::error!(CAT, imp: self.imp, "{err}");
+ gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
+ })?;
+
+ let Some(event) = event else {
+ gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS");
+ return Ok(TranscriptEvent::Eos);
+ };
+
if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
let mut ready_items = None;
@@ -188,10 +188,7 @@ impl TranscriberLoop {
}
if let Some(ready_items) = ready_items {
- if self.transcript_items_tx.send(ready_items.into()).is_err() {
- gst::debug!(CAT, imp: self.imp, "No transcript items receivers");
- break;
- }
+ return Ok(ready_items.into());
}
} else {
gst::warning!(
@@ -201,13 +198,6 @@ impl TranscriberLoop {
)
}
}
-
- gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS");
- let _ = self.transcript_items_tx.send(TranscriptEvent::Eos);
-
- gst::debug!(CAT, imp: self.imp, "Exiting transcriber loop");
-
- Ok(())
}
/// Builds a list from the provided stable items.
diff --git a/net/aws/src/transcriber/translate.rs b/net/aws/src/transcriber/translate.rs
index b944e001..71f43aef 100644
--- a/net/aws/src/transcriber/translate.rs
+++ b/net/aws/src/transcriber/translate.rs
@@ -14,7 +14,7 @@ use aws_sdk_translate as aws_translate;
use futures::channel::mpsc;
use futures::prelude::*;
-use std::collections::VecDeque;
+use std::sync::Arc;
use super::imp::TranslateSrcPad;
use super::transcribe::TranscriptItem;
@@ -40,77 +40,13 @@ impl From<&TranscriptItem> for TranslatedItem {
}
}
-#[derive(Default)]
-pub struct TranslateQueue {
- items: VecDeque<TranscriptItem>,
-}
-
-impl TranslateQueue {
- pub fn is_empty(&self) -> bool {
- self.items.is_empty()
- }
-
- /// Pushes the provided item.
- ///
- /// Returns `Some(..)` if items are ready for translation.
- pub fn push(&mut self, transcript_item: &TranscriptItem) -> Option<Vec<TranscriptItem>> {
- // Keep track of the item individually so we can schedule translation precisely.
- self.items.push_back(transcript_item.clone());
-
- if transcript_item.is_punctuation {
- // This makes it a good chunk for translation.
- // Concatenate as a single item for translation
-
- return Some(self.items.drain(..).collect());
- }
-
- // Regular case: no separator detected, don't push transcript items
- // to translation now. They will be pushed either if a punctuation
- // is found or of a `dequeue()` is requested.
-
- None
- }
-
- /// Dequeues items from the specified `deadline` up to `lookahead`.
- ///
- /// Returns `Some(..)` if some items match the criteria.
- pub fn dequeue(
- &mut self,
- latency: gst::ClockTime,
- threshold: gst::ClockTime,
- lookahead: gst::ClockTime,
- ) -> Option<Vec<TranscriptItem>> {
- let first_pts = self.items.front()?.pts;
- if first_pts + latency > threshold {
- // First item is too early to be sent to translation now
- // we can wait for more items to accumulate.
- return None;
- }
-
- // Can't wait any longer to send the first item to translation
- // Try to get up to lookahead worth of items to improve translation accuracy
- let limit = first_pts + lookahead;
-
- let mut items_acc = vec![self.items.pop_front().unwrap()];
- while let Some(item) = self.items.front() {
- if item.pts > limit {
- break;
- }
-
- items_acc.push(self.items.pop_front().unwrap());
- }
-
- Some(items_acc)
- }
-}
-
pub struct TranslateLoop {
pad: glib::subclass::ObjectImplRef<TranslateSrcPad>,
client: aws_translate::Client,
input_lang: String,
output_lang: String,
tokenization_method: TranslationTokenizationMethod,
- transcript_rx: mpsc::Receiver<Vec<TranscriptItem>>,
+ transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
translate_tx: mpsc::Sender<Vec<TranslatedItem>>,
}
@@ -121,7 +57,7 @@ impl TranslateLoop {
input_lang: &str,
output_lang: &str,
tokenization_method: TranslationTokenizationMethod,
- transcript_rx: mpsc::Receiver<Vec<TranscriptItem>>,
+ transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
translate_tx: mpsc::Sender<Vec<TranslatedItem>>,
) -> Self {
let aws_config = imp.aws_config.lock().unwrap();
@@ -175,12 +111,12 @@ impl TranslateLoop {
let (ts_duration_list, content): (Vec<(gst::ClockTime, gst::ClockTime)>, String) =
transcript_items
- .into_iter()
+ .iter()
.map(|item| {
(
(item.pts, item.duration),
match self.tokenization_method {
- Tokenization::None => item.content,
+ Tokenization::None => item.content.clone(),
Tokenization::SpanBased => {
format!("{SPAN_START}{}{SPAN_END}", item.content)
}