From 2b32d00589c23d3914f1b825074d96532df5d930 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Laignel?= Date: Thu, 16 Mar 2023 18:20:08 +0100 Subject: net/aws/transcriber: use two queues for sending transcript items * A queue dedicated to transcript items not intended for translation. * A queue dedicated to transcript items intended for translation. The items are enqueued after a separator is detected or translate-lookahead was reached. Part-of: --- net/aws/src/transcriber/imp.rs | 520 +++++++++++++++++++++------------- net/aws/src/transcriber/transcribe.rs | 104 +++---- net/aws/src/transcriber/translate.rs | 74 +---- 3 files changed, 374 insertions(+), 324 deletions(-) (limited to 'net/aws') diff --git a/net/aws/src/transcriber/imp.rs b/net/aws/src/transcriber/imp.rs index 0c99c760..cc3f8ad2 100644 --- a/net/aws/src/transcriber/imp.rs +++ b/net/aws/src/transcriber/imp.rs @@ -29,12 +29,12 @@ use futures::prelude::*; use tokio::{runtime, sync::broadcast, task}; use std::collections::{BTreeSet, VecDeque}; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; use once_cell::sync::Lazy; -use super::transcribe::{TranscriberLoop, TranscriptEvent, TranscriptItem, TranscriptionSettings}; -use super::translate::{TranslateLoop, TranslateQueue, TranslatedItem}; +use super::transcribe::{TranscriberSettings, TranscriberStream, TranscriptEvent, TranscriptItem}; +use super::translate::{TranslateLoop, TranslatedItem}; use super::{ AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod, TranslationTokenizationMethod, CAT, @@ -148,6 +148,7 @@ struct State { srcpads: BTreeSet, pad_serial: u32, seqnum: gst::Seqnum, + start_time: Option, } impl Default for State { @@ -158,6 +159,7 @@ impl Default for State { srcpads: Default::default(), pad_serial: 0, seqnum: gst::Seqnum::next(), + start_time: None, } } } @@ -168,7 +170,9 @@ pub struct Transcriber { settings: Mutex, state: Mutex, pub(super) aws_config: Mutex>, - // sender to broadcast transcript items to the translate src pads. + // sender to broadcast transcript items to the src pads for translation. + transcript_event_for_translate_tx: broadcast::Sender, + // sender to broadcast transcript items to the src pads, not intended for translation. transcript_event_tx: broadcast::Sender, } @@ -276,7 +280,10 @@ impl Transcriber { ) -> Result { gst::log!(CAT, obj: pad, "Handling {buffer:?}"); - self.ensure_connection(); + self.ensure_connection().map_err(|err| { + gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]); + gst::FlowError::Error + })?; let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else { gst::log!(CAT, obj: pad, "Flushing"); @@ -292,12 +299,82 @@ impl Transcriber { Ok(gst::FlowSuccess::Ok) } +} + +#[derive(Default)] +struct TranslateQueue { + items: VecDeque, +} + +impl TranslateQueue { + fn is_empty(&self) -> bool { + self.items.is_empty() + } + + /// Pushes the provided item. + /// + /// Returns `Some(..)` if items are ready for translation. + fn push(&mut self, transcript_item: &TranscriptItem) -> Option> { + // Keep track of the item individually so we can schedule translation precisely. + self.items.push_back(transcript_item.clone()); + + if transcript_item.is_punctuation { + // This makes it a good chunk for translation. + // Concatenate as a single item for translation + + return Some(self.items.drain(..).collect()); + } + + // Regular case: no separator detected, don't push transcript items + // to translation now. They will be pushed either if a punctuation + // is found or of a `dequeue()` is requested. + + None + } + + /// Dequeues items from the specified `deadline` up to `lookahead`. + /// + /// Returns `Some(..)` if some items match the criteria. + fn dequeue( + &mut self, + latency: gst::ClockTime, + threshold: gst::ClockTime, + lookahead: gst::ClockTime, + ) -> Option> { + let first_pts = self.items.front()?.pts; + if first_pts + latency > threshold { + // First item is too early to be sent to translation now + // we can wait for more items to accumulate. + return None; + } + + // Can't wait any longer to send the first item to translation + // Try to get up to lookahead worth of items to improve translation accuracy + let limit = first_pts + lookahead; + + let mut items_acc = vec![self.items.pop_front().unwrap()]; + while let Some(item) = self.items.front() { + if item.pts > limit { + break; + } + + items_acc.push(self.items.pop_front().unwrap()); + } - fn ensure_connection(&self) { + Some(items_acc) + } + + fn drain(&mut self) -> impl Iterator + '_ { + self.items.drain(..) + } +} + +impl Transcriber { + fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> { let mut state = self.state.lock().unwrap(); if state.buffer_tx.is_some() { - return; + return Ok(()); } let settings = self.settings.lock().unwrap(); @@ -306,21 +383,116 @@ impl Transcriber { let s = in_caps.structure(0).unwrap(); let sample_rate = s.get::("rate").unwrap(); - let transcription_settings = TranscriptionSettings::from(&settings, sample_rate); + let transcription_settings = TranscriberSettings::from(&settings, sample_rate); let (buffer_tx, buffer_rx) = mpsc::channel(1); - let transcriber_loop = TranscriberLoop::new( + let _enter = RUNTIME.enter(); + let mut transcriber_stream = futures::executor::block_on(TranscriberStream::try_new( self, transcription_settings, settings.lateness, buffer_rx, - self.transcript_event_tx.clone(), - ); - let transcriber_loop_handle = RUNTIME.spawn(transcriber_loop.run()); + ))?; + + // Latency budget for an item to be pushed to stream on time + // Margin: + // - 2 * GRANULARITY: to make sure we don't push items up to GRANULARITY late. + // - 1 * GRANULARITY: extra margin to account for additional overheads. + let latency = settings.transcribe_latency.saturating_sub(3 * GRANULARITY); + let translate_lookahead = settings.translate_lookahead; + let mut translate_queue = TranslateQueue::default(); + let imp = self.ref_counted(); + let transcriber_loop_handle = RUNTIME.spawn(async move { + loop { + // This is to make sure we send items on a timely basis or at least Gap events. + let timeout = tokio::time::sleep(GRANULARITY.into()).fuse(); + futures::pin_mut!(timeout); + + let transcriber_next = transcriber_stream.next().fuse(); + futures::pin_mut!(transcriber_next); + + // `transcriber_next` takes precedence over `timeout` + // because we don't want to loose any incoming items. + let res = futures::select_biased! { + event = transcriber_next => Some(event?), + _ = timeout => None, + }; + + use TranscriptEvent::*; + match res { + None => (), + Some(Items(items)) => { + if imp.transcript_event_tx.receiver_count() > 0 { + let _ = imp.transcript_event_tx.send(Items(items.clone())); + } + + if imp.transcript_event_for_translate_tx.receiver_count() > 0 { + for item in items.iter() { + if let Some(items_to_translate) = translate_queue.push(item) { + let _ = imp + .transcript_event_for_translate_tx + .send(Items(items_to_translate.into())); + } + } + } + } + Some(Eos) => { + gst::debug!(CAT, imp: imp, "Transcriber loop sending EOS"); + + if imp.transcript_event_tx.receiver_count() > 0 { + let _ = imp.transcript_event_tx.send(Eos); + } + + if imp.transcript_event_for_translate_tx.receiver_count() > 0 { + let items_to_translate: Vec = + translate_queue.drain().collect(); + let _ = imp + .transcript_event_for_translate_tx + .send(Items(items_to_translate.into())); + + let _ = imp.transcript_event_for_translate_tx.send(Eos); + } + + break; + } + } + + if imp.transcript_event_for_translate_tx.receiver_count() > 0 { + // Check if we need to push items for translation + + let Some((start_time, now)) = imp.get_start_time_and_now() else { + continue; + }; + + if !translate_queue.is_empty() { + let threshold = now - start_time; + + if let Some(items_to_translate) = + translate_queue.dequeue(latency, threshold, translate_lookahead) + { + gst::debug!( + CAT, + imp: imp, + "Forcing to translation (threshold {threshold}): {items_to_translate:?}" + ); + let _ = imp + .transcript_event_for_translate_tx + .send(Items(items_to_translate.into())); + } + } + } + } + + gst::debug!(CAT, imp: imp, "Exiting transcriber loop"); + + Ok(()) + }); state.transcriber_loop_handle = Some(transcriber_loop_handle); state.buffer_tx = Some(buffer_tx); + + Ok(()) } fn prepare(&self) -> Result<(), gst::ErrorMessage> { @@ -382,6 +554,18 @@ impl Transcriber { } gst::info!(CAT, imp: self, "Unprepared"); } + + fn get_start_time_and_now(&self) -> Option<(gst::ClockTime, gst::ClockTime)> { + let now = self.obj().current_running_time()?; + + let mut state = self.state.lock().unwrap(); + + if state.start_time.is_none() { + state.start_time = Some(now); + } + + Some((state.start_time.unwrap(), now)) + } } #[glib::object_subclass] @@ -438,6 +622,7 @@ impl ObjectSubclass for Transcriber { // Setting the channel capacity so that a TranslateSrcPad that would lag // behind for some reasons get a chance to catch-up without loosing items. // Receiver will be created by subscribing to sender later. + let (transcript_event_for_translate_tx, _) = broadcast::channel(128); let (transcript_event_tx, _) = broadcast::channel(128); Self { @@ -446,6 +631,7 @@ impl ObjectSubclass for Transcriber { settings: Default::default(), state: Default::default(), aws_config: Default::default(), + transcript_event_for_translate_tx, transcript_event_tx, } } @@ -876,51 +1062,93 @@ struct TranslationPadTask { elem: super::Transcriber, transcript_event_rx: broadcast::Receiver, needs_translate: bool, - translate_queue: TranslateQueue, translate_loop_handle: Option>>, - to_translate_tx: Option>>, + to_translate_tx: Option>>>, from_translate_rx: Option>>, - translate_latency: gst::ClockTime, - translate_lookahead: gst::ClockTime, send_events: bool, output_items: VecDeque, our_latency: gst::ClockTime, seqnum: gst::Seqnum, send_eos: bool, pending_translations: usize, - start_time: Option, } impl TranslationPadTask { - fn try_new( + async fn try_new( pad: &TranslateSrcPad, elem: super::Transcriber, - transcript_event_rx: broadcast::Receiver, ) -> Result { - let mut this = TranslationPadTask { + let mut translation_loop = None; + let mut translate_loop_handle = None; + let mut to_translate_tx = None; + let mut from_translate_rx = None; + + let (our_latency, transcript_event_rx, needs_translate); + + { + let elem_imp = elem.imp(); + let elem_settings = elem_imp.settings.lock().unwrap(); + + let pad_settings = pad.settings.lock().unwrap(); + + our_latency = TranslateSrcPad::our_latency(&elem_settings, &pad_settings); + if our_latency + elem_settings.lateness <= 2 * GRANULARITY { + let err = format!( + "total latency + lateness must be greater than {}", + 2 * GRANULARITY + ); + gst::error!(CAT, imp: pad, "{err}"); + return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"])); + } + + needs_translate = TranslateSrcPad::needs_translation( + &elem_settings.language_code, + pad_settings.language_code.as_deref(), + ); + + if needs_translate { + let (to_loop_tx, to_loop_rx) = mpsc::channel(64); + let (from_loop_tx, from_loop_rx) = mpsc::channel(64); + + translation_loop = Some(TranslateLoop::new( + elem_imp, + pad, + &elem_settings.language_code, + pad_settings.language_code.as_deref().unwrap(), + pad_settings.tokenization_method, + to_loop_rx, + from_loop_tx, + )); + + to_translate_tx = Some(to_loop_tx); + from_translate_rx = Some(from_loop_rx); + + transcript_event_rx = elem_imp.transcript_event_for_translate_tx.subscribe(); + } else { + transcript_event_rx = elem_imp.transcript_event_tx.subscribe(); + } + } + + if let Some(translation_loop) = translation_loop { + translation_loop.check_language().await?; + translate_loop_handle = Some(RUNTIME.spawn(translation_loop.run())); + } + + Ok(TranslationPadTask { pad: pad.ref_counted(), elem, transcript_event_rx, - needs_translate: false, - translate_queue: TranslateQueue::default(), - translate_loop_handle: None, - to_translate_tx: None, - from_translate_rx: None, - translate_latency: DEFAULT_TRANSLATE_LATENCY, - translate_lookahead: DEFAULT_TRANSLATE_LOOKAHEAD, + needs_translate, + translate_loop_handle, + to_translate_tx, + from_translate_rx, send_events: true, output_items: VecDeque::new(), - our_latency: DEFAULT_TRANSCRIBE_LATENCY, + our_latency, seqnum: gst::Seqnum::next(), send_eos: false, pending_translations: 0, - start_time: None, - }; - - let _enter_guard = RUNTIME.enter(); - futures::executor::block_on(this.init_translate())?; - - Ok(this) + }) } } @@ -958,11 +1186,9 @@ impl TranslationPadTask { let transcript_event_rx = self.transcript_event_rx.recv().fuse(); futures::pin_mut!(transcript_event_rx); - // `timeout` takes precedence over `transcript_events` reception - // because we may need to `dequeue` `items` or push a `Gap` event - // before current latency budget is exhausted. + // `transcript_event_rx` takes precedence over `timeout` + // because we don't want to loose any incoming items. futures::select_biased! { - _ = timeout => (), items_res = transcript_event_rx => { use TranscriptEvent::*; use broadcast::error::RecvError; @@ -983,6 +1209,7 @@ impl TranslationPadTask { } } } + _ = timeout => (), } Ok(()) @@ -999,121 +1226,100 @@ impl TranslationPadTask { return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"])); } - let transcript_items = { + let items_to_translate = { // This is to make sure we send items on a timely basis or at least Gap events. let timeout = tokio::time::sleep(GRANULARITY.into()).fuse(); futures::pin_mut!(timeout); - let from_translate_rx = self - .from_translate_rx - .as_mut() - .expect("from_translation chan must be available in translation mode"); - let transcript_event_rx = self.transcript_event_rx.recv().fuse(); futures::pin_mut!(transcript_event_rx); - // `timeout` takes precedence over `transcript_events` reception - // because we may need to `dequeue` `items` or push a `Gap` event - // before current latency budget is exhausted. + // `transcript_event_rx` takes precedence over `timeout` + // because we don't want to loose any incoming items. futures::select_biased! { - _ = timeout => return Ok(()), - translated_items = from_translate_rx.next() => { - let Some(translated_items) = translated_items else { - const ERR: &str = "translation chan terminated"; - gst::debug!(CAT, imp: self.pad, "{ERR}"); - return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"])); - }; - - self.output_items.extend(translated_items.into_iter().map(Into::into)); - self.pending_translations = self.pending_translations.saturating_sub(1); - - return Ok(()); - } items_res = transcript_event_rx => { use TranscriptEvent::*; use broadcast::error::RecvError; match items_res { - Ok(Items(transcript_items)) => transcript_items, + Ok(Items(items_to_translate)) => Some(items_to_translate), Ok(Eos) => { gst::debug!(CAT, imp: self.pad, "Got eos"); self.send_eos = true; - return Ok(()); + None } Err(RecvError::Lagged(nb_msg)) => { gst::warning!(CAT, imp: self.pad, "Missed {nb_msg} transcript sets"); - return Ok(()); + None } Err(RecvError::Closed) => { gst::debug!(CAT, imp: self.pad, "Transcript chan terminated: setting eos"); self.send_eos = true; - return Ok(()); + None } } } + _ = timeout => None, } }; - for items in transcript_items.iter() { - if let Some(items_to_translate) = self.translate_queue.push(items) { - self.send_for_translation(items_to_translate).await?; + if let Some(items_to_translate) = items_to_translate { + if !items_to_translate.is_empty() { + let res = self + .to_translate_tx + .as_mut() + .expect("to_translation chan must be available in translation mode") + .send(items_to_translate) + .await; + + if res.is_err() { + const ERR: &str = "to_translation chan terminated"; + gst::debug!(CAT, imp: self.pad, "{ERR}"); + return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"])); + } + + self.pending_translations += 1; } } - Ok(()) - } + // Check pending translated items + let from_translate_rx = self + .from_translate_rx + .as_mut() + .expect("from_translation chan must be available in translation mode"); - async fn dequeue_for_translation( - &mut self, - start_time: gst::ClockTime, - now: gst::ClockTime, - ) -> Result<(), gst::ErrorMessage> { - if !self.translate_queue.is_empty() { - // Latency budget for an item to be pushed to stream on time - // Margin: - // - 2 * GRANULARITY: to make sure we don't push items up to GRANULARITY late. - // - 1 * GRANULARITY: extra margin to account for additional overheads. - let latency = self.our_latency.saturating_sub(3 * GRANULARITY); - - // Estimated time of arrival for an item sent to translation now. - // (in transcript item ts base) - let translation_eta = now + self.translate_latency - start_time; - - if let Some(items_to_translate) = - self.translate_queue - .dequeue(latency, translation_eta, self.translate_lookahead) - { - gst::debug!(CAT, imp: self.pad, "Forcing to translation: {items_to_translate:?}"); - self.send_for_translation(items_to_translate).await?; - } + while let Ok(translated_items) = from_translate_rx.try_next() { + let Some(translated_items) = translated_items else { + const ERR: &str = "translation chan terminated"; + gst::debug!(CAT, imp: self.pad, "{ERR}"); + return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"])); + }; + + self.output_items + .extend(translated_items.into_iter().map(Into::into)); + self.pending_translations = self.pending_translations.saturating_sub(1); } Ok(()) } async fn dequeue(&mut self) -> bool { - let (now, start_time, mut last_position, mut discont_pending); - { - let mut pad_state = self.pad.state.lock().unwrap(); - - let Some(cur_rt) = self.elem.current_running_time() else { - // Wait for the clock to be available - return true; - }; - now = cur_rt; + let Some((start_time, now)) = self.elem.imp().get_start_time_and_now() else { + // Wait for the clock to be available + return true; + }; - if self.start_time.is_none() { - self.start_time = Some(now); - pad_state.out_segment.set_position(now); - } + let (mut last_position, mut discont_pending) = { + let mut state = self.pad.state.lock().unwrap(); - start_time = self.start_time.unwrap(); - last_position = pad_state.out_segment.position().unwrap(); - discont_pending = pad_state.discont_pending; - } + let last_position = if let Some(pos) = state.out_segment.position() { + pos + } else { + state.out_segment.set_position(start_time); + start_time + }; - if self.needs_translate && self.dequeue_for_translation(start_time, now).await.is_err() { - return false; - } + (last_position, state.discont_pending) + }; /* First, check our pending buffers */ while let Some(item) = self.output_items.front() { @@ -1206,11 +1412,7 @@ impl TranslationPadTask { } } - if self.send_eos - && self.pending_translations == 0 - && self.output_items.is_empty() - && self.translate_queue.is_empty() - { + if self.send_eos && self.pending_translations == 0 && self.output_items.is_empty() { /* We're EOS, we can pause and exit early */ let _ = self.pad.obj().pause_task(); @@ -1261,28 +1463,6 @@ impl TranslationPadTask { true } - async fn send_for_translation( - &mut self, - transcript_items: Vec, - ) -> Result<(), gst::ErrorMessage> { - let res = self - .to_translate_tx - .as_mut() - .expect("to_translation chan must be available in translation mode") - .send(transcript_items) - .await; - - if res.is_err() { - const ERR: &str = "to_translation chan terminated"; - gst::debug!(CAT, imp: self.pad, "{ERR}"); - return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"])); - } - - self.pending_translations += 1; - - Ok(()) - } - fn ensure_init_events(&mut self) -> Result<(), gst::ErrorMessage> { if !self.send_events { return Ok(()); @@ -1332,62 +1512,6 @@ impl TranslationPadTask { } } -impl TranslationPadTask { - async fn init_translate(&mut self) -> Result<(), gst::ErrorMessage> { - let mut translation_loop = None; - - { - let elem_imp = self.elem.imp(); - let elem_settings = elem_imp.settings.lock().unwrap(); - - let pad_settings = self.pad.settings.lock().unwrap(); - - self.our_latency = TranslateSrcPad::our_latency(&elem_settings, &pad_settings); - if self.our_latency + elem_settings.lateness <= 2 * GRANULARITY { - let err = format!( - "total latency + lateness must be greater than {}", - 2 * GRANULARITY - ); - gst::error!(CAT, imp: self.pad, "{err}"); - return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"])); - } - - self.translate_latency = elem_settings.translate_latency; - self.translate_lookahead = elem_settings.translate_lookahead; - - self.needs_translate = TranslateSrcPad::needs_translation( - &elem_settings.language_code, - pad_settings.language_code.as_deref(), - ); - - if self.needs_translate { - let (to_translate_tx, to_translate_rx) = mpsc::channel(64); - let (from_translate_tx, from_translate_rx) = mpsc::channel(64); - - translation_loop = Some(TranslateLoop::new( - elem_imp, - &self.pad, - &elem_settings.language_code, - pad_settings.language_code.as_deref().unwrap(), - pad_settings.tokenization_method, - to_translate_rx, - from_translate_tx, - )); - - self.to_translate_tx = Some(to_translate_tx); - self.from_translate_rx = Some(from_translate_rx); - } - } - - if let Some(translation_loop) = translation_loop { - translation_loop.check_language().await?; - self.translate_loop_handle = Some(RUNTIME.spawn(translation_loop.run())); - } - - Ok(()) - } -} - #[derive(Debug)] struct TranslationPadState { discont_pending: bool, @@ -1422,8 +1546,8 @@ impl TranslateSrcPad { gst::debug!(CAT, imp: self, "Starting task"); let elem = self.parent(); - let transcript_event_rx = elem.imp().transcript_event_tx.subscribe(); - let mut pad_task = TranslationPadTask::try_new(self, elem, transcript_event_rx) + let _enter = RUNTIME.enter(); + let mut pad_task = futures::executor::block_on(TranslationPadTask::try_new(self, elem)) .map_err(|err| gst::loggable_error!(CAT, format!("Failed to start pad task {err}")))?; let imp = self.ref_counted(); diff --git a/net/aws/src/transcriber/transcribe.rs b/net/aws/src/transcriber/transcribe.rs index 97301380..d094d87a 100644 --- a/net/aws/src/transcriber/transcribe.rs +++ b/net/aws/src/transcriber/transcribe.rs @@ -15,7 +15,6 @@ use aws_sdk_transcribestreaming::model; use futures::channel::mpsc; use futures::prelude::*; -use tokio::sync::broadcast; use std::sync::Arc; @@ -23,7 +22,7 @@ use super::imp::{Settings, Transcriber}; use super::CAT; #[derive(Debug)] -pub struct TranscriptionSettings { +pub struct TranscriberSettings { lang_code: model::LanguageCode, sample_rate: i32, vocabulary: Option, @@ -33,9 +32,9 @@ pub struct TranscriptionSettings { results_stability: model::PartialResultsStability, } -impl TranscriptionSettings { +impl TranscriberSettings { pub(super) fn from(settings: &Settings, sample_rate: i32) -> Self { - TranscriptionSettings { + TranscriberSettings { lang_code: settings.language_code.as_str().into(), sample_rate, vocabulary: settings.vocabulary.clone(), @@ -83,43 +82,30 @@ impl From> for TranscriptEvent { } } -pub struct TranscriberLoop { +pub struct TranscriberStream { imp: glib::subclass::ObjectImplRef, - client: aws_transcribe::Client, - settings: Option, + output: aws_transcribe::output::StartStreamTranscriptionOutput, lateness: gst::ClockTime, - buffer_rx: Option>, - transcript_items_tx: broadcast::Sender, partial_index: usize, } -impl TranscriberLoop { - pub fn new( +impl TranscriberStream { + pub async fn try_new( imp: &Transcriber, - settings: TranscriptionSettings, + settings: TranscriberSettings, lateness: gst::ClockTime, buffer_rx: mpsc::Receiver, - transcript_items_tx: broadcast::Sender, - ) -> Self { - let aws_config = imp.aws_config.lock().unwrap(); - let aws_config = aws_config - .as_ref() - .expect("aws_config must be initialized at this stage"); - - TranscriberLoop { - imp: imp.ref_counted(), - client: aws_transcribe::Client::new(aws_config), - settings: Some(settings), - lateness, - buffer_rx: Some(buffer_rx), - transcript_items_tx, - partial_index: 0, - } - } + ) -> Result { + let client = { + let aws_config = imp.aws_config.lock().unwrap(); + let aws_config = aws_config + .as_ref() + .expect("aws_config must be initialized at this stage"); + aws_transcribe::Client::new(aws_config) + }; - pub async fn run(mut self) -> Result<(), gst::ErrorMessage> { // Stream the incoming buffers chunked - let chunk_stream = self.buffer_rx.take().unwrap().flat_map(move |buffer: gst::Buffer| { + let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| { async_stream::stream! { let data = buffer.map_readable().unwrap(); use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob}; @@ -129,9 +115,7 @@ impl TranscriberLoop { } }); - let settings = self.settings.take().unwrap(); - let mut transcribe_builder = self - .client + let mut transcribe_builder = client .start_stream_transcription() .language_code(settings.lang_code) .media_sample_rate_hertz(settings.sample_rate) @@ -147,26 +131,42 @@ impl TranscriberLoop { .vocabulary_filter_method(settings.vocabulary_filter_method); } - let mut output = transcribe_builder + let output = transcribe_builder .audio_stream(chunk_stream.into()) .send() .await .map_err(|err| { let err = format!("Transcribe ws init error: {err}"); - gst::error!(CAT, imp: self.imp, "{err}"); + gst::error!(CAT, imp: imp, "{err}"); gst::error_msg!(gst::LibraryError::Init, ["{err}"]) })?; - while let Some(event) = output - .transcript_result_stream - .recv() - .await - .map_err(|err| { - let err = format!("Transcribe ws stream error: {err}"); - gst::error!(CAT, imp: self.imp, "{err}"); - gst::error_msg!(gst::LibraryError::Failed, ["{err}"]) - })? - { + Ok(TranscriberStream { + imp: imp.ref_counted(), + output, + lateness, + partial_index: 0, + }) + } + + pub async fn next(&mut self) -> Result { + loop { + let event = self + .output + .transcript_result_stream + .recv() + .await + .map_err(|err| { + let err = format!("Transcribe ws stream error: {err}"); + gst::error!(CAT, imp: self.imp, "{err}"); + gst::error_msg!(gst::LibraryError::Failed, ["{err}"]) + })?; + + let Some(event) = event else { + gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS"); + return Ok(TranscriptEvent::Eos); + }; + if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { let mut ready_items = None; @@ -188,10 +188,7 @@ impl TranscriberLoop { } if let Some(ready_items) = ready_items { - if self.transcript_items_tx.send(ready_items.into()).is_err() { - gst::debug!(CAT, imp: self.imp, "No transcript items receivers"); - break; - } + return Ok(ready_items.into()); } } else { gst::warning!( @@ -201,13 +198,6 @@ impl TranscriberLoop { ) } } - - gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS"); - let _ = self.transcript_items_tx.send(TranscriptEvent::Eos); - - gst::debug!(CAT, imp: self.imp, "Exiting transcriber loop"); - - Ok(()) } /// Builds a list from the provided stable items. diff --git a/net/aws/src/transcriber/translate.rs b/net/aws/src/transcriber/translate.rs index b944e001..71f43aef 100644 --- a/net/aws/src/transcriber/translate.rs +++ b/net/aws/src/transcriber/translate.rs @@ -14,7 +14,7 @@ use aws_sdk_translate as aws_translate; use futures::channel::mpsc; use futures::prelude::*; -use std::collections::VecDeque; +use std::sync::Arc; use super::imp::TranslateSrcPad; use super::transcribe::TranscriptItem; @@ -40,77 +40,13 @@ impl From<&TranscriptItem> for TranslatedItem { } } -#[derive(Default)] -pub struct TranslateQueue { - items: VecDeque, -} - -impl TranslateQueue { - pub fn is_empty(&self) -> bool { - self.items.is_empty() - } - - /// Pushes the provided item. - /// - /// Returns `Some(..)` if items are ready for translation. - pub fn push(&mut self, transcript_item: &TranscriptItem) -> Option> { - // Keep track of the item individually so we can schedule translation precisely. - self.items.push_back(transcript_item.clone()); - - if transcript_item.is_punctuation { - // This makes it a good chunk for translation. - // Concatenate as a single item for translation - - return Some(self.items.drain(..).collect()); - } - - // Regular case: no separator detected, don't push transcript items - // to translation now. They will be pushed either if a punctuation - // is found or of a `dequeue()` is requested. - - None - } - - /// Dequeues items from the specified `deadline` up to `lookahead`. - /// - /// Returns `Some(..)` if some items match the criteria. - pub fn dequeue( - &mut self, - latency: gst::ClockTime, - threshold: gst::ClockTime, - lookahead: gst::ClockTime, - ) -> Option> { - let first_pts = self.items.front()?.pts; - if first_pts + latency > threshold { - // First item is too early to be sent to translation now - // we can wait for more items to accumulate. - return None; - } - - // Can't wait any longer to send the first item to translation - // Try to get up to lookahead worth of items to improve translation accuracy - let limit = first_pts + lookahead; - - let mut items_acc = vec![self.items.pop_front().unwrap()]; - while let Some(item) = self.items.front() { - if item.pts > limit { - break; - } - - items_acc.push(self.items.pop_front().unwrap()); - } - - Some(items_acc) - } -} - pub struct TranslateLoop { pad: glib::subclass::ObjectImplRef, client: aws_translate::Client, input_lang: String, output_lang: String, tokenization_method: TranslationTokenizationMethod, - transcript_rx: mpsc::Receiver>, + transcript_rx: mpsc::Receiver>>, translate_tx: mpsc::Sender>, } @@ -121,7 +57,7 @@ impl TranslateLoop { input_lang: &str, output_lang: &str, tokenization_method: TranslationTokenizationMethod, - transcript_rx: mpsc::Receiver>, + transcript_rx: mpsc::Receiver>>, translate_tx: mpsc::Sender>, ) -> Self { let aws_config = imp.aws_config.lock().unwrap(); @@ -175,12 +111,12 @@ impl TranslateLoop { let (ts_duration_list, content): (Vec<(gst::ClockTime, gst::ClockTime)>, String) = transcript_items - .into_iter() + .iter() .map(|item| { ( (item.pts, item.duration), match self.tokenization_method { - Tokenization::None => item.content, + Tokenization::None => item.content.clone(), Tokenization::SpanBased => { format!("{SPAN_START}{}{SPAN_END}", item.content) } -- cgit v1.2.3