Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/sdroege/gst-plugin-rs.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/net/aws
diff options
context:
space:
mode:
authorFrançois Laignel <francois@centricular.com>2023-03-10 16:47:38 +0300
committerGStreamer Marge Bot <gitlab-merge-bot@gstreamer-foundation.org>2023-03-14 16:48:32 +0300
commit743e97738fe44bec8843f467828a1ec2aa710d91 (patch)
treec23da089726cd9560d52d58521c560415a44fb52 /net/aws
parent9a55fda69c1c83ba5aabcc0796f0af9b031fc6da (diff)
net/aws/transcriber: add translation request src pads
This commit adds an optional transcript translation feature implemented as request src Pads. When requesting a src Pad, the user can specify the translation language code using Pad properties 'language-code'. The following properties are defined on the Element: - 'transcribe-latency': formerly 'latency', defines the expected latency for the Transcribe webservice. - 'translate-latency': defines the expected latency for the Translate webservice. - 'transcript-lookahead': maximum transcript duration to send to translation when a transcript is hitting its deadline and no punctuation was found. When the input and output languages are the same, only the 'transcribe-latency' is used for the Pad. Otherwise, the resulting latency is the addition of 'transcribe-latency' and 'translate-latency'. Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1109>
Diffstat (limited to 'net/aws')
-rw-r--r--net/aws/Cargo.toml1
-rw-r--r--net/aws/src/transcriber/imp.rs1839
-rw-r--r--net/aws/src/transcriber/mod.rs19
-rw-r--r--net/aws/src/transcriber/transcribe.rs277
-rw-r--r--net/aws/src/transcriber/translate.rs215
5 files changed, 1626 insertions, 725 deletions
diff --git a/net/aws/Cargo.toml b/net/aws/Cargo.toml
index c329009f..663263f9 100644
--- a/net/aws/Cargo.toml
+++ b/net/aws/Cargo.toml
@@ -16,6 +16,7 @@ base32 = "0.4"
aws-config = "0.54.0"
aws-sdk-s3 = "0.24.0"
aws-sdk-transcribestreaming = "0.24.0"
+aws-sdk-translate = "0.24.0"
aws-types = "0.54.0"
aws-credential-types = "0.54.0"
aws-sig-auth = "0.54.0"
diff --git a/net/aws/src/transcriber/imp.rs b/net/aws/src/transcriber/imp.rs
index e4116d32..d856e6d4 100644
--- a/net/aws/src/transcriber/imp.rs
+++ b/net/aws/src/transcriber/imp.rs
@@ -7,72 +7,97 @@
//
// SPDX-License-Identifier: MPL-2.0
-use gst::glib;
-use gst::prelude::*;
+//! AWS Transcriber element.
+//!
+//! This element calls AWS Transcribe to extract transcripts from an audio stream.
+//! The element can optionally translate the resulting transcripts to one or
+//! multiple languages.
+//!
+//! This module contains the element implementation as well as the `TranslationSrcPad`
+//! sublcass and its `TranslationPadTask`.
+//!
+//! Web service specific code can be found in the `transcribe` and `translate` modules.
+
use gst::subclass::prelude::*;
+use gst::{glib, prelude::*};
use aws_sdk_transcribestreaming as aws_transcribe;
-use aws_sdk_transcribestreaming::model;
use futures::channel::mpsc;
use futures::future::AbortHandle;
use futures::prelude::*;
-use tokio::{runtime, task};
+use tokio::{runtime, sync::broadcast, task};
-use std::cmp::Ordering;
-use std::collections::VecDeque;
+use std::collections::{BTreeSet, VecDeque};
use std::sync::Mutex;
use once_cell::sync::Lazy;
-use super::{AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod};
-
-static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
- gst::DebugCategory::new(
- "awstranscribe",
- gst::DebugColorFlags::empty(),
- Some("AWS Transcribe element"),
- )
-});
+use super::transcribe::{TranscriberLoop, TranscriptEvent, TranscriptItem, TranscriptionSettings};
+use super::translate::{TranslatedItem, TranslationLoop, TranslationQueue};
+use super::{AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod, CAT};
static RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
runtime::Builder::new_multi_thread()
.enable_all()
- .worker_threads(1)
.build()
.unwrap()
});
const DEFAULT_TRANSCRIBER_REGION: &str = "us-east-1";
-const DEFAULT_LATENCY: gst::ClockTime = gst::ClockTime::from_seconds(8);
+
+// Deprecated in 0.11.0: due to evolutions of the transcriber element,
+// this property has been replaced by `TRANSCRIBE_LATENCY_PROPERTY`.
+const DEPRECATED_LATENCY_PROPERTY: &str = "latency";
+
+const TRANSCRIBE_LATENCY_PROPERTY: &str = "transcribe-latency";
+pub const DEFAULT_TRANSCRIBE_LATENCY: gst::ClockTime = gst::ClockTime::from_seconds(8);
+
+const TRANSLATE_LATENCY_PROPERTY: &str = "translate-latency";
+pub const DEFAULT_TRANSLATE_LATENCY: gst::ClockTime = gst::ClockTime::from_mseconds(500);
+
+const TRANSCRIPT_LOOKAHEAD_PROPERTY: &str = "transcript-lookahead";
+pub const DEFAULT_TRANSCRIPT_LOOKAHEAD: gst::ClockTime = gst::ClockTime::from_seconds(5);
+
const DEFAULT_LATENESS: gst::ClockTime = gst::ClockTime::ZERO;
-const DEFAULT_LANGUAGE_CODE: &str = "en-US";
+pub const DEFAULT_INPUT_LANG_CODE: &str = "en-US";
+
const DEFAULT_STABILITY: AwsTranscriberResultStability = AwsTranscriberResultStability::Low;
const DEFAULT_VOCABULARY_FILTER_METHOD: AwsTranscriberVocabularyFilterMethod =
AwsTranscriberVocabularyFilterMethod::Mask;
-const GRANULARITY: gst::ClockTime = gst::ClockTime::from_mseconds(100);
+
+// The period at which the event loops will check if they need to push
+// anything downstream when no other events show up.
+pub const GRANULARITY: gst::ClockTime = gst::ClockTime::from_mseconds(100);
+
+const OUTPUT_LANG_CODE_PROPERTY: &str = "language-code";
+const DEFAULT_OUTPUT_LANG_CODE: Option<&str> = None;
#[derive(Debug, Clone)]
-struct Settings {
- latency: gst::ClockTime,
+pub(super) struct Settings {
+ transcribe_latency: gst::ClockTime,
+ translate_latency: gst::ClockTime,
+ transcript_lookahead: gst::ClockTime,
lateness: gst::ClockTime,
- language_code: String,
- vocabulary: Option<String>,
- session_id: Option<String>,
- results_stability: AwsTranscriberResultStability,
+ pub language_code: String,
+ pub vocabulary: Option<String>,
+ pub session_id: Option<String>,
+ pub results_stability: AwsTranscriberResultStability,
access_key: Option<String>,
secret_access_key: Option<String>,
session_token: Option<String>,
- vocabulary_filter: Option<String>,
- vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod,
+ pub vocabulary_filter: Option<String>,
+ pub vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod,
}
impl Default for Settings {
fn default() -> Self {
Self {
- latency: DEFAULT_LATENCY,
+ transcribe_latency: DEFAULT_TRANSCRIBE_LATENCY,
+ translate_latency: DEFAULT_TRANSLATE_LATENCY,
+ transcript_lookahead: DEFAULT_TRANSCRIPT_LOOKAHEAD,
lateness: DEFAULT_LATENESS,
- language_code: DEFAULT_LANGUAGE_CODE.to_string(),
+ language_code: DEFAULT_INPUT_LANG_CODE.to_string(),
vocabulary: None,
session_id: None,
results_stability: DEFAULT_STABILITY,
@@ -85,578 +110,72 @@ impl Default for Settings {
}
}
-#[derive(Debug)]
-struct TranscriptionSettings {
- lang_code: model::LanguageCode,
- sample_rate: i32,
- vocabulary: Option<String>,
- vocabulary_filter: Option<String>,
- vocabulary_filter_method: model::VocabularyFilterMethod,
- session_id: Option<String>,
- results_stability: model::PartialResultsStability,
-}
-
-impl TranscriptionSettings {
- fn from(settings: &Settings, sample_rate: i32) -> Self {
- TranscriptionSettings {
- lang_code: settings.language_code.as_str().into(),
- sample_rate,
- vocabulary: settings.vocabulary.clone(),
- vocabulary_filter: settings.vocabulary_filter.clone(),
- vocabulary_filter_method: settings.vocabulary_filter_method.into(),
- session_id: settings.session_id.clone(),
- results_stability: settings.results_stability.into(),
- }
- }
-}
-
-struct TranscriberLoop {
- imp: glib::subclass::ObjectImplRef<Transcriber>,
- client: aws_transcribe::Client,
- settings: TranscriptionSettings,
- lateness: gst::ClockTime,
- buffer_rx: mpsc::Receiver<gst::Buffer>,
- transcript_notif_tx: mpsc::Sender<()>,
-}
-
-impl TranscriberLoop {
- fn new(
- imp: &Transcriber,
- aws_config: &aws_config::SdkConfig,
- settings: TranscriptionSettings,
- lateness: gst::ClockTime,
- buffer_rx: mpsc::Receiver<gst::Buffer>,
- transcript_notif_tx: mpsc::Sender<()>,
- ) -> Self {
- TranscriberLoop {
- imp: imp.ref_counted(),
- client: aws_transcribe::Client::new(aws_config),
- settings,
- lateness,
- buffer_rx,
- transcript_notif_tx,
- }
- }
-
- async fn run(mut self) -> Result<(), gst::ErrorMessage> {
- // Stream the incoming buffers chunked
- let chunk_stream = self.buffer_rx.flat_map(move |buffer: gst::Buffer| {
- async_stream::stream! {
- let data = buffer.map_readable().unwrap();
- use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
- for chunk in data.chunks(8192) {
- yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build()));
- }
- }
- });
-
- let mut transcribe_builder = self
- .client
- .start_stream_transcription()
- .language_code(self.settings.lang_code)
- .media_sample_rate_hertz(self.settings.sample_rate)
- .media_encoding(model::MediaEncoding::Pcm)
- .enable_partial_results_stabilization(true)
- .partial_results_stability(self.settings.results_stability)
- .set_vocabulary_name(self.settings.vocabulary)
- .set_session_id(self.settings.session_id);
-
- if let Some(vocabulary_filter) = self.settings.vocabulary_filter {
- transcribe_builder = transcribe_builder
- .vocabulary_filter_name(vocabulary_filter)
- .vocabulary_filter_method(self.settings.vocabulary_filter_method);
- }
-
- let mut output = transcribe_builder
- .audio_stream(chunk_stream.into())
- .send()
- .await
- .map_err(|err| {
- let err = format!("Transcribe ws init error: {err}");
- gst::error!(CAT, imp: self.imp, "{err}");
- gst::error_msg!(gst::LibraryError::Init, ["{err}"])
- })?;
-
- while let Some(event) = output
- .transcript_result_stream
- .recv()
- .await
- .map_err(|err| {
- let err = format!("Transcribe ws stream error: {err}");
- gst::error!(CAT, imp: self.imp, "{err}");
- gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
- })?
- {
- if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
- let mut enqueued = false;
-
- if let Some(result) = transcript_evt
- .transcript
- .and_then(|transcript| transcript.results)
- .and_then(|mut results| results.drain(..).next())
- {
- gst::trace!(CAT, imp: self.imp, "Received: {result:?}");
-
- if let Some(alternative) = result
- .alternatives
- .and_then(|mut alternatives| alternatives.drain(..).next())
- {
- if let Some(items) = alternative.items {
- enqueued = self.imp.enqueue(items, result.is_partial, self.lateness);
- }
- }
- }
-
- if enqueued && self.transcript_notif_tx.send(()).await.is_err() {
- gst::debug!(CAT, imp: self.imp, "Terminated transcript_notif_tx channel");
- break;
- }
- } else {
- gst::warning!(
- CAT,
- imp: self.imp,
- "Transcribe ws returned unknown event: consider upgrading the SDK"
- )
- }
- }
-
- gst::debug!(CAT, imp: self.imp, "Exiting ws loop");
-
- Ok(())
- }
-}
-
-struct State {
- aws_config: Option<aws_config::SdkConfig>,
+pub(super) struct State {
buffer_tx: Option<mpsc::Sender<gst::Buffer>>,
- transcript_notif_tx: Option<mpsc::Sender<()>>,
- ws_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
- task_abort_handle: Option<AbortHandle>,
- in_segment: gst::FormattedSegment<gst::ClockTime>,
- out_segment: gst::FormattedSegment<gst::ClockTime>,
- seqnum: gst::Seqnum,
- buffers: VecDeque<gst::Buffer>,
- send_eos: bool,
- discont: bool,
- partial_index: usize,
- send_events: bool,
- start_time: Option<gst::ClockTime>,
+ transcriber_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
+ srcpads: BTreeSet<super::TranslationSrcPad>,
+ pad_serial: u32,
+ pub seqnum: gst::Seqnum,
}
impl Default for State {
fn default() -> Self {
Self {
- aws_config: None,
buffer_tx: None,
- transcript_notif_tx: None,
- ws_loop_handle: None,
- task_abort_handle: None,
- in_segment: gst::FormattedSegment::new(),
- out_segment: gst::FormattedSegment::new(),
+ transcriber_loop_handle: None,
+ srcpads: Default::default(),
+ pad_serial: 0,
seqnum: gst::Seqnum::next(),
- buffers: VecDeque::new(),
- send_eos: false,
- discont: true,
- partial_index: 0,
- send_events: true,
- start_time: None,
}
}
}
pub struct Transcriber {
- srcpad: gst::Pad,
+ static_srcpad: super::TranslationSrcPad,
sinkpad: gst::Pad,
settings: Mutex<Settings>,
state: Mutex<State>,
+ pub(super) aws_config: Mutex<Option<aws_config::SdkConfig>>,
+ // sender to broadcast transcript items to the translation src pads.
+ transcript_event_tx: broadcast::Sender<TranscriptEvent>,
}
impl Transcriber {
- fn dequeue(&self) -> bool {
- /* First, check our pending buffers */
- let mut items = vec![];
-
- let Some(now) = self.obj().current_running_time() else { return true };
-
- let latency = self.settings.lock().unwrap().latency;
-
- let mut state = self.state.lock().unwrap();
+ fn start_srcpad_tasks(&self, state: &State) -> Result<(), gst::LoggableError> {
+ gst::debug!(CAT, imp: self, "Starting tasks");
- if state.start_time.is_none() {
- state.start_time = Some(now);
- state.out_segment.set_position(now);
+ if self.static_srcpad.is_linked() {
+ self.static_srcpad.imp().start_task()?;
}
- let start_time = state.start_time.unwrap();
- let mut last_position = state.out_segment.position().unwrap();
-
- let send_eos = state.send_eos && state.buffers.is_empty();
-
- while let Some(buf) = state.buffers.front() {
- let pts = buf.pts().unwrap();
- gst::trace!(
- CAT,
- imp: self,
- "Checking now {now} if item is ready for dequeuing, PTS {pts}, threshold {} vs {}",
- pts + latency.saturating_sub(3 * GRANULARITY),
- now - start_time
- );
-
- if pts + latency.saturating_sub(3 * GRANULARITY) < now - start_time {
- /* Safe unwrap, we know we have an item */
- let mut buf = state.buffers.pop_front().unwrap();
-
- {
- let buf_mut = buf.get_mut().unwrap();
-
- buf_mut.set_pts(start_time + pts);
- }
-
- items.push(buf);
- } else {
- break;
- }
+ for pad in state.srcpads.iter() {
+ pad.imp().start_task()?;
}
- let seqnum = state.seqnum;
-
- drop(state);
-
- /* We're EOS, we can pause and exit early */
- if send_eos {
- let _ = self.srcpad.pause_task();
-
- return self
- .srcpad
- .push_event(gst::event::Eos::builder().seqnum(seqnum).build());
- }
-
- for mut buf in items.drain(..) {
- let mut pts = buf.pts().unwrap();
- let mut duration = buf.duration().unwrap();
-
- match pts.cmp(&last_position) {
- Ordering::Greater => {
- let gap_event = gst::event::Gap::builder(last_position)
- .duration(pts - last_position)
- .seqnum(seqnum)
- .build();
- gst::log!(CAT, "Pushing gap: {last_position} -> {pts}");
- if !self.srcpad.push_event(gap_event) {
- return false;
- }
- }
- Ordering::Less => {
- let delta = last_position - pts;
-
- gst::warning!(
- CAT,
- imp: self,
- "Updating item PTS ({pts} < {last_position}), consider increasing latency",
- );
-
- pts = last_position;
- duration = duration.saturating_sub(delta);
-
- {
- let buf_mut = buf.get_mut().unwrap();
-
- buf_mut.set_pts(pts);
- buf_mut.set_duration(duration);
- }
- }
- _ => (),
- }
-
- last_position = pts + duration;
-
- gst::debug!(CAT, "Pushing buffer: {pts} -> {}", pts + duration);
-
- if self.srcpad.push(buf).is_err() {
- return false;
- }
- }
-
- /* next, push a gap if we're lagging behind the target position */
- gst::trace!(
- CAT,
- imp: self,
- "Checking now: {now} if we need to push a gap, last_position: {last_position}, threshold: {}",
- last_position + latency.saturating_sub(GRANULARITY)
- );
-
- if now > last_position + latency.saturating_sub(GRANULARITY) {
- let duration = now - last_position - latency.saturating_sub(GRANULARITY);
-
- let gap_event = gst::event::Gap::builder(last_position)
- .duration(duration)
- .seqnum(seqnum)
- .build();
-
- gst::log!(
- CAT,
- "Pushing gap: {last_position} -> {}",
- last_position + duration
- );
-
- last_position += duration;
-
- if !self.srcpad.push_event(gap_event) {
- return false;
- }
- }
-
- self.state
- .lock()
- .unwrap()
- .out_segment
- .set_position(last_position);
-
- true
- }
-
- /// Enqueues a buffer for each of the provided stable items.
- ///
- /// Returns `true` if at least one buffer was enqueued.
- fn enqueue(
- &self,
- mut items: Vec<model::Item>,
- partial: bool,
- lateness: gst::ClockTime,
- ) -> bool {
- let mut state = self.state.lock().unwrap();
-
- if items.len() <= state.partial_index {
- gst::error!(
- CAT,
- imp: self,
- "sanity check failed, alternative length {} < partial_index {}",
- items.len(),
- state.partial_index
- );
-
- if !partial {
- state.partial_index = 0;
- }
-
- return false;
- }
-
- let mut enqueued = false;
-
- for item in items.drain(state.partial_index..) {
- if !item.stable().unwrap_or(false) {
- break;
- }
-
- let Some(content) = item.content else { continue };
-
- let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness;
- let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness;
-
- /* Should be sent now */
- gst::debug!(
- CAT,
- imp: self,
- "Item is ready for queuing: {content}, PTS {start_time}",
- );
-
- let mut buf = gst::Buffer::from_mut_slice(content.into_bytes());
- {
- let buf = buf.get_mut().unwrap();
-
- if state.discont {
- buf.set_flags(gst::BufferFlags::DISCONT);
- state.discont = false;
- }
-
- buf.set_pts(start_time);
- buf.set_duration(end_time - start_time);
- }
-
- state.partial_index += 1;
-
- state.buffers.push_back(buf);
- enqueued = true;
- }
-
- if !partial {
- state.partial_index = 0;
- }
-
- enqueued
- }
-
- fn pad_loop_fn(&self, transcript_notif_rx: &mut mpsc::Receiver<()>) {
- let mut events = {
- let mut events = vec![];
-
- let state = self.state.lock().unwrap();
- if state.send_events {
- events.push(
- gst::event::StreamStart::builder("transcription")
- .seqnum(state.seqnum)
- .build(),
- );
-
- let caps = gst::Caps::builder("text/x-raw")
- .field("format", "utf8")
- .build();
- events.push(
- gst::event::Caps::builder(&caps)
- .seqnum(state.seqnum)
- .build(),
- );
-
- events.push(
- gst::event::Segment::builder(&state.out_segment)
- .seqnum(state.seqnum)
- .build(),
- );
- }
-
- events
- };
-
- if !events.is_empty() {
- for event in events.drain(..) {
- gst::info!(CAT, imp: self, "Sending {event:?}");
- self.srcpad.push_event(event);
- }
-
- self.state.lock().unwrap().send_events = false;
- }
-
- let future = async move {
- let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
- futures::pin_mut!(timeout);
-
- futures::select! {
- notif = transcript_notif_rx.next() => {
- if notif.is_none() {
- // Transcriber loop terminated
- self.state.lock().unwrap().send_eos = true;
- };
- }
- _ = timeout => (),
- };
-
- if !self.dequeue() {
- gst::info!(CAT, imp: self, "Failed to dequeue buffer, pausing");
- let _ = self.srcpad.pause_task();
- }
- };
-
- let (abortable_future, abort_handle) = future::abortable(future);
- self.state.lock().unwrap().task_abort_handle = Some(abort_handle);
-
- let _enter = RUNTIME.enter();
- if futures::executor::block_on(abortable_future).is_err() {
- gst::debug!(CAT, imp: self, "task iter aborted");
- }
- }
-
- fn start_task(&self) -> Result<(), gst::LoggableError> {
- gst::debug!(CAT, imp: self, "Starting task");
- let mut state = self.state.lock().unwrap();
-
- let (transcript_notif_tx, mut transcript_notif_rx) = mpsc::channel(1);
-
- let imp = self.ref_counted();
- let res = self
- .srcpad
- .start_task(move || imp.pad_loop_fn(&mut transcript_notif_rx));
-
- if res.is_err() {
- state.transcript_notif_tx = None;
- return Err(gst::loggable_error!(CAT, "Failed to start pad task"));
- }
-
- state.transcript_notif_tx = Some(transcript_notif_tx);
-
- gst::debug!(CAT, imp: self, "Task started");
+ gst::debug!(CAT, imp: self, "Tasks Started");
Ok(())
}
- fn stop_task(&self) {
- gst::debug!(CAT, imp: self, "Stopping task");
-
- let _ = self.srcpad.stop_task();
+ fn stop_tasks(&self, state: &mut State) {
+ gst::debug!(CAT, imp: self, "Stopping tasks");
- let mut state = self.state.lock().unwrap();
-
- if let Some(task_abort_handle) = state.task_abort_handle.take() {
- task_abort_handle.abort();
+ if self.static_srcpad.is_linked() {
+ self.static_srcpad.imp().stop_task();
}
- if let Some(ws_loop_handle) = state.ws_loop_handle.take() {
- ws_loop_handle.abort();
+ for pad in state.srcpads.iter() {
+ pad.imp().stop_task();
}
- state.transcript_notif_tx = None;
+ // Terminate the audio buffer stream
state.buffer_tx = None;
- gst::debug!(CAT, imp: self, "Task Stopped");
- }
-
- fn stop_ws_loop(&self) {
- let mut state = self.state.lock().unwrap();
-
- if let Some(ws_loop_handle) = state.ws_loop_handle.take() {
- ws_loop_handle.abort();
- }
-
- state.buffer_tx = None;
- }
-
- fn src_activatemode(
- &self,
- _pad: &gst::Pad,
- _mode: gst::PadMode,
- active: bool,
- ) -> Result<(), gst::LoggableError> {
- if active {
- self.start_task()?;
- } else {
- self.stop_task();
+ if let Some(transcriber_loop_handle) = state.transcriber_loop_handle.take() {
+ transcriber_loop_handle.abort();
}
- Ok(())
- }
-
- fn src_query(&self, pad: &gst::Pad, query: &mut gst::QueryRef) -> bool {
- gst::log!(CAT, obj: pad, "Handling query {query:?}");
-
- use gst::QueryViewMut::*;
- match query.view_mut() {
- Latency(q) => {
- let mut peer_query = gst::query::Latency::new();
-
- let ret = self.sinkpad.peer_query(&mut peer_query);
-
- if ret {
- let (_, min, _) = peer_query.result();
- let our_latency = self.settings.lock().unwrap().latency;
- q.set(true, our_latency + min, gst::ClockTime::NONE);
- }
- ret
- }
- Position(q) => {
- if q.format() == gst::Format::Time {
- let state = self.state.lock().unwrap();
- q.set(
- state
- .out_segment
- .to_stream_time(state.out_segment.position()),
- );
- true
- } else {
- false
- }
- }
- _ => gst::Pad::query_default(pad, Some(&*self.obj()), query),
- }
+ gst::debug!(CAT, imp: self, "Tasks Stopped");
}
fn sink_event(&self, pad: &gst::Pad, event: gst::Event) -> bool {
@@ -665,14 +184,15 @@ impl Transcriber {
use gst::EventView::*;
match event.view() {
Eos(_) => {
- self.stop_ws_loop();
+ // Terminate the audio buffer stream
+ self.state.lock().unwrap().buffer_tx = None;
true
}
FlushStart(_) => {
gst::info!(CAT, imp: self, "Received flush start, disconnecting");
let ret = gst::Pad::event_default(pad, Some(&*self.obj()), event);
- self.stop_task();
+ self.stop_tasks(&mut self.state.lock().unwrap());
ret
}
@@ -680,9 +200,10 @@ impl Transcriber {
gst::info!(CAT, imp: self, "Received flush stop, restarting task");
if gst::Pad::event_default(pad, Some(&*self.obj()), event) {
- match self.start_task() {
+ let state = self.state.lock().unwrap();
+ match self.start_srcpad_tasks(&state) {
Err(err) => {
- gst::error!(CAT, imp: self, "Failed to start srcpad task: {err}");
+ gst::error!(CAT, imp: self, "Failed to start srcpad tasks: {err}");
false
}
Ok(_) => true,
@@ -692,22 +213,17 @@ impl Transcriber {
}
}
Segment(e) => {
- let segment = match e.segment().clone().downcast::<gst::ClockTime>() {
- Err(segment) => {
- gst::element_imp_error!(
- self,
- gst::StreamError::Format,
- ["Only Time segments supported, got {:?}", segment.format()]
- );
- return false;
- }
- Ok(segment) => segment,
+ let format = e.segment().format();
+ if format != gst::Format::Time {
+ gst::element_imp_error!(
+ self,
+ gst::StreamError::Format,
+ ["Only Time segments supported, got {format:?}"]
+ );
+ return false;
};
- let mut state = self.state.lock().unwrap();
-
- state.in_segment = segment;
- state.seqnum = e.seqnum();
+ self.state.lock().unwrap().seqnum = e.seqnum();
true
}
@@ -728,10 +244,7 @@ impl Transcriber {
) -> Result<gst::FlowSuccess, gst::FlowError> {
gst::log!(CAT, obj: pad, "Handling {buffer:?}");
- self.ensure_connection().map_err(|err| {
- gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]);
- gst::FlowError::Error
- })?;
+ self.ensure_connection();
let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else {
gst::log!(CAT, obj: pad, "Flushing");
@@ -748,132 +261,93 @@ impl Transcriber {
Ok(gst::FlowSuccess::Ok)
}
- fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> {
- enum ConfigStatus {
- Ready(aws_config::SdkConfig),
- NotReady {
- access_key: Option<String>,
- secret_access_key: Option<String>,
- session_token: Option<String>,
- },
- }
-
- let (config_status, transcription_settings, lateness, transcript_notif_tx);
- {
- let mut state = self.state.lock().unwrap();
-
- if let Some(ref ws_loop_handle) = state.ws_loop_handle {
- if ws_loop_handle.is_finished() {
- state.ws_loop_handle = None;
-
- const ERR: &str = "ws loop terminated unexpectedly";
- gst::error!(CAT, imp: self, "{ERR}");
- return Err(gst::error_msg!(gst::LibraryError::Failed, ["{ERR}"]));
- }
+ fn ensure_connection(&self) {
+ let mut state = self.state.lock().unwrap();
- return Ok(());
- }
+ if state.buffer_tx.is_some() {
+ return;
+ }
- transcript_notif_tx = state
- .transcript_notif_tx
- .take()
- .expect("attempting to spawn the ws loop, but the srcpad task hasn't been started");
+ let settings = self.settings.lock().unwrap();
- let settings = self.settings.lock().unwrap();
+ let in_caps = self.sinkpad.current_caps().unwrap();
+ let s = in_caps.structure(0).unwrap();
+ let sample_rate = s.get::<i32>("rate").unwrap();
- lateness = settings.lateness;
- if settings.latency + lateness <= 2 * GRANULARITY {
- const ERR: &str = "latency + lateness must be greater than 200 milliseconds";
- gst::error!(CAT, imp: self, "{ERR}");
- return Err(gst::error_msg!(gst::LibraryError::Settings, ["{ERR}"]));
- }
+ let transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
- let in_caps = self.sinkpad.current_caps().unwrap();
- let s = in_caps.structure(0).unwrap();
- let sample_rate = s.get::<i32>("rate").unwrap();
+ let (buffer_tx, buffer_rx) = mpsc::channel(1);
- transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
+ let transcriber_loop = TranscriberLoop::new(
+ self,
+ transcription_settings,
+ settings.lateness,
+ buffer_rx,
+ self.transcript_event_tx.clone(),
+ );
+ let transcriber_loop_handle = RUNTIME.spawn(transcriber_loop.run());
- config_status = if let Some(aws_config) = state.aws_config.take() {
- ConfigStatus::Ready(aws_config)
- } else {
- ConfigStatus::NotReady {
- access_key: settings.access_key.to_owned(),
- secret_access_key: settings.secret_access_key.to_owned(),
- session_token: settings.session_token.to_owned(),
- }
- };
- };
+ state.transcriber_loop_handle = Some(transcriber_loop_handle);
+ state.buffer_tx = Some(buffer_tx);
+ }
- let aws_config = match config_status {
- ConfigStatus::Ready(aws_config) => aws_config,
- ConfigStatus::NotReady {
- access_key,
- secret_access_key,
- session_token,
- } => {
- gst::info!(CAT, imp: self, "Loading aws config...");
- let _enter_guard = RUNTIME.enter();
-
- let config_loader = match (access_key, secret_access_key) {
- (Some(key), Some(secret_key)) => {
- gst::debug!(CAT, imp: self, "Using settings credentials");
- aws_config::ConfigLoader::default().credentials_provider(
- aws_transcribe::Credentials::new(
- key,
- secret_key,
- session_token,
- None,
- "translate",
- ),
- )
- }
- _ => {
- gst::debug!(CAT, imp: self, "Attempting to get credentials from env...");
- aws_config::from_env()
- }
- };
+ fn prepare(&self) -> Result<(), gst::ErrorMessage> {
+ gst::debug!(CAT, imp: self, "Preparing");
- let config_loader = config_loader.region(
- aws_config::meta::region::RegionProviderChain::default_provider()
- .or_else(DEFAULT_TRANSCRIBER_REGION),
- );
- let config = futures::executor::block_on(config_loader.load());
- gst::debug!(CAT, imp: self, "Using region {}", config.region().unwrap());
+ let (access_key, secret_access_key, session_token);
+ {
+ let settings = self.settings.lock().unwrap();
+ access_key = settings.access_key.to_owned();
+ secret_access_key = settings.secret_access_key.to_owned();
+ session_token = settings.session_token.to_owned();
+ }
- config
+ gst::info!(CAT, imp: self, "Loading aws config...");
+ let _enter_guard = RUNTIME.enter();
+
+ let config_loader = match (access_key, secret_access_key) {
+ (Some(key), Some(secret_key)) => {
+ gst::debug!(CAT, imp: self, "Using settings credentials");
+ aws_config::ConfigLoader::default().credentials_provider(
+ aws_transcribe::Credentials::new(
+ key,
+ secret_key,
+ session_token,
+ None,
+ "translate",
+ ),
+ )
+ }
+ _ => {
+ gst::debug!(CAT, imp: self, "Attempting to get credentials from env...");
+ aws_config::from_env()
}
};
- let mut state = self.state.lock().unwrap();
+ let config_loader = config_loader.region(
+ aws_config::meta::region::RegionProviderChain::default_provider()
+ .or_else(DEFAULT_TRANSCRIBER_REGION),
+ );
- let (buffer_tx, buffer_rx) = mpsc::channel(1);
+ let config = futures::executor::block_on(config_loader.load());
+ gst::debug!(CAT, imp: self, "Using region {}", config.region().unwrap());
- let ws_loop_ctx = TranscriberLoop::new(
- self,
- &aws_config,
- transcription_settings,
- lateness,
- buffer_rx,
- transcript_notif_tx,
- );
- let ws_loop_handle = RUNTIME.spawn(ws_loop_ctx.run());
+ *self.aws_config.lock().unwrap() = Some(config);
- state.aws_config = Some(aws_config);
- state.ws_loop_handle = Some(ws_loop_handle);
- state.buffer_tx = Some(buffer_tx);
+ gst::debug!(CAT, imp: self, "Prepared");
Ok(())
}
fn disconnect(&self) {
gst::info!(CAT, imp: self, "Unpreparing");
+ let mut state = self.state.lock().unwrap();
- self.stop_task();
-
- // Also resets discont to true
- *self.state.lock().unwrap() = State::default();
+ self.stop_tasks(&mut state);
+ for pad in state.srcpads.iter() {
+ pad.imp().set_discont();
+ }
gst::info!(CAT, imp: self, "Unprepared");
}
}
@@ -883,6 +357,7 @@ impl ObjectSubclass for Transcriber {
const NAME: &'static str = "GstAwsTranscriber";
type Type = super::Transcriber;
type ParentType = gst::Element;
+ type Interfaces = (gst::ChildProxy,);
fn with_class(klass: &Self::Class) -> Self {
let templ = klass.pad_template("sink").unwrap();
@@ -904,36 +379,42 @@ impl ObjectSubclass for Transcriber {
.build();
let templ = klass.pad_template("src").unwrap();
- let srcpad = gst::Pad::builder_with_template(&templ, Some("src"))
- .activatemode_function(|pad, parent, mode, active| {
- Transcriber::catch_panic_pad_function(
- parent,
- || {
- Err(gst::loggable_error!(
- CAT,
- "Panic activating src pad with mode"
- ))
- },
- |transcriber| transcriber.src_activatemode(pad, mode, active),
- )
- })
- .query_function(|pad, parent, query| {
- Transcriber::catch_panic_pad_function(
- parent,
- || false,
- |transcriber| transcriber.src_query(pad, query),
- )
- })
- .flags(gst::PadFlags::FIXED_CAPS)
- .build();
+ let static_srcpad =
+ gst::PadBuilder::<super::TranslationSrcPad>::from_template(&templ, Some("src"))
+ .activatemode_function(|pad, parent, mode, active| {
+ Transcriber::catch_panic_pad_function(
+ parent,
+ || {
+ Err(gst::loggable_error!(
+ CAT,
+ "Panic activating TranslationSrcPad"
+ ))
+ },
+ |elem| TranslationSrcPad::activatemode(elem, pad, mode, active),
+ )
+ })
+ .query_function(|pad, parent, query| {
+ Transcriber::catch_panic_pad_function(
+ parent,
+ || false,
+ |elem| TranslationSrcPad::src_query(elem, pad, query),
+ )
+ })
+ .flags(gst::PadFlags::FIXED_CAPS)
+ .build();
- let settings = Mutex::new(Settings::default());
+ // Setting the channel capacity so that a TranslationSrcPad that would lag
+ // behind for some reasons get a chance to catch-up without loosing items.
+ // Receiver will be created by subscribing to sender later.
+ let (transcript_event_tx, _) = broadcast::channel(128);
Self {
- srcpad,
+ static_srcpad,
sinkpad,
- settings,
- state: Mutex::new(State::default()),
+ settings: Default::default(),
+ state: Default::default(),
+ aws_config: Default::default(),
+ transcript_event_tx,
}
}
}
@@ -947,13 +428,38 @@ impl ObjectImpl for Transcriber {
.blurb("The Language of the Stream, see \
<https://docs.aws.amazon.com/transcribe/latest/dg/how-streaming-transcription.html> \
for an up to date list of allowed languages")
- .default_value(Some(DEFAULT_LANGUAGE_CODE))
+ .default_value(Some(DEFAULT_INPUT_LANG_CODE))
.mutable_ready()
.build(),
- glib::ParamSpecUInt::builder("latency")
+ glib::ParamSpecUInt::builder(DEPRECATED_LATENCY_PROPERTY)
.nick("Latency")
+ .blurb("Amount of milliseconds to allow AWS transcribe (Deprecated. Use transcribe-latency)")
+ .default_value(DEFAULT_TRANSCRIBE_LATENCY.mseconds() as u32)
+ .mutable_ready()
+ .deprecated()
+ .build(),
+ glib::ParamSpecUInt::builder(TRANSCRIBE_LATENCY_PROPERTY)
+ .nick("AWS Transcribe Latency")
.blurb("Amount of milliseconds to allow AWS transcribe")
- .default_value(DEFAULT_LATENCY.mseconds() as u32)
+ .default_value(DEFAULT_TRANSCRIBE_LATENCY.mseconds() as u32)
+ .mutable_ready()
+ .build(),
+ glib::ParamSpecUInt::builder(TRANSLATE_LATENCY_PROPERTY)
+ .nick("AWS Translate Latency")
+ .blurb(concat!(
+ "Amount of milliseconds to allow AWS translate ",
+ "(ignored if the input and output languages are the same)",
+ ))
+ .default_value(DEFAULT_TRANSLATE_LATENCY.mseconds() as u32)
+ .mutable_ready()
+ .build(),
+ glib::ParamSpecUInt::builder(TRANSCRIPT_LOOKAHEAD_PROPERTY)
+ .nick("Transcript chunk")
+ .blurb(concat!(
+ "Maximum duration in milliseconds of transcript to lookahead ",
+ "before sending to translation when no separator was encountered",
+ ))
+ .default_value(DEFAULT_TRANSCRIPT_LOOKAHEAD.mseconds() as u32)
.mutable_ready()
.build(),
glib::ParamSpecUInt::builder("lateness")
@@ -1017,7 +523,7 @@ impl ObjectImpl for Transcriber {
let obj = self.obj();
obj.add_pad(&self.sinkpad).unwrap();
- obj.add_pad(&self.srcpad).unwrap();
+ obj.add_pad(&self.static_srcpad).unwrap();
obj.set_element_flags(gst::ElementFlags::PROVIDE_CLOCK | gst::ElementFlags::REQUIRE_CLOCK);
}
@@ -1027,12 +533,26 @@ impl ObjectImpl for Transcriber {
let mut settings = self.settings.lock().unwrap();
settings.language_code = value.get().expect("type checked upstream");
}
- "latency" => {
+ DEPRECATED_LATENCY_PROPERTY => {
let mut settings = self.settings.lock().unwrap();
- settings.latency = gst::ClockTime::from_mseconds(
+ settings.transcribe_latency = gst::ClockTime::from_mseconds(
value.get::<u32>().expect("type checked upstream").into(),
);
}
+ TRANSCRIBE_LATENCY_PROPERTY => {
+ let mut settings = self.settings.lock().unwrap();
+ settings.transcribe_latency = gst::ClockTime::from_mseconds(
+ value.get::<u32>().expect("type checked upstream").into(),
+ );
+ }
+ TRANSLATE_LATENCY_PROPERTY => {
+ self.settings.lock().unwrap().translate_latency =
+ gst::ClockTime::from_mseconds(value.get::<u32>().unwrap().into());
+ }
+ TRANSCRIPT_LOOKAHEAD_PROPERTY => {
+ self.settings.lock().unwrap().transcript_lookahead =
+ gst::ClockTime::from_mseconds(value.get::<u32>().unwrap().into());
+ }
"lateness" => {
let mut settings = self.settings.lock().unwrap();
settings.lateness = gst::ClockTime::from_mseconds(
@@ -1085,10 +605,24 @@ impl ObjectImpl for Transcriber {
let settings = self.settings.lock().unwrap();
settings.language_code.to_value()
}
- "latency" => {
+ DEPRECATED_LATENCY_PROPERTY => {
let settings = self.settings.lock().unwrap();
- (settings.latency.mseconds() as u32).to_value()
+ (settings.transcribe_latency.mseconds() as u32).to_value()
}
+ TRANSCRIBE_LATENCY_PROPERTY => {
+ let settings = self.settings.lock().unwrap();
+ (settings.transcribe_latency.mseconds() as u32).to_value()
+ }
+ TRANSLATE_LATENCY_PROPERTY => {
+ (self.settings.lock().unwrap().translate_latency.mseconds() as u32).to_value()
+ }
+ TRANSCRIPT_LOOKAHEAD_PROPERTY => (self
+ .settings
+ .lock()
+ .unwrap()
+ .transcript_lookahead
+ .mseconds() as u32)
+ .to_value(),
"lateness" => {
let settings = self.settings.lock().unwrap();
(settings.lateness.mseconds() as u32).to_value()
@@ -1151,11 +685,20 @@ impl ElementImpl for Transcriber {
let src_caps = gst::Caps::builder("text/x-raw")
.field("format", "utf8")
.build();
- let src_pad_template = gst::PadTemplate::new(
+ let src_pad_template = gst::PadTemplate::with_gtype(
"src",
gst::PadDirection::Src,
gst::PadPresence::Always,
&src_caps,
+ super::TranslationSrcPad::static_type(),
+ )
+ .unwrap();
+ let req_src_pad_template = gst::PadTemplate::with_gtype(
+ "translation_src_%u",
+ gst::PadDirection::Src,
+ gst::PadPresence::Request,
+ &src_caps,
+ super::TranslationSrcPad::static_type(),
)
.unwrap();
@@ -1172,7 +715,7 @@ impl ElementImpl for Transcriber {
)
.unwrap();
- vec![src_pad_template, sink_pad_template]
+ vec![src_pad_template, req_src_pad_template, sink_pad_template]
});
PAD_TEMPLATES.as_ref()
@@ -1184,6 +727,13 @@ impl ElementImpl for Transcriber {
) -> Result<gst::StateChangeSuccess, gst::StateChangeError> {
gst::info!(CAT, imp: self, "Changing state {transition:?}");
+ if let gst::StateChange::NullToReady = transition {
+ self.prepare().map_err(|err| {
+ self.post_error_message(err);
+ gst::StateChangeError
+ })?;
+ }
+
let mut success = self.parent_change_state(transition)?;
match transition {
@@ -1202,7 +752,848 @@ impl ElementImpl for Transcriber {
Ok(success)
}
+ fn request_new_pad(
+ &self,
+ templ: &gst::PadTemplate,
+ _name: Option<&str>,
+ _caps: Option<&gst::Caps>,
+ ) -> Option<gst::Pad> {
+ let mut state = self.state.lock().unwrap();
+
+ let pad = gst::PadBuilder::<super::TranslationSrcPad>::from_template(
+ templ,
+ Some(format!("translation_src_{}", state.pad_serial).as_str()),
+ )
+ .activatemode_function(|pad, parent, mode, active| {
+ Transcriber::catch_panic_pad_function(
+ parent,
+ || {
+ Err(gst::loggable_error!(
+ CAT,
+ "Panic activating TranslationSrcPad"
+ ))
+ },
+ |elem| TranslationSrcPad::activatemode(elem, pad, mode, active),
+ )
+ })
+ .query_function(|pad, parent, query| {
+ Transcriber::catch_panic_pad_function(
+ parent,
+ || false,
+ |elem| TranslationSrcPad::src_query(elem, pad, query),
+ )
+ })
+ .flags(gst::PadFlags::FIXED_CAPS)
+ .build();
+
+ state.srcpads.insert(pad.clone());
+
+ state.pad_serial += 1;
+ drop(state);
+
+ self.obj().add_pad(&pad).unwrap();
+
+ let _ = self
+ .obj()
+ .post_message(gst::message::Latency::builder().src(&*self.obj()).build());
+
+ self.obj().child_added(&pad, &pad.name());
+ Some(pad.upcast())
+ }
+
+ fn release_pad(&self, pad: &gst::Pad) {
+ pad.set_active(false).unwrap();
+ self.obj().remove_pad(pad).unwrap();
+
+ self.obj().child_removed(pad, &pad.name());
+ let _ = self
+ .obj()
+ .post_message(gst::message::Latency::builder().src(&*self.obj()).build());
+ }
+
fn provide_clock(&self) -> Option<gst::Clock> {
Some(gst::SystemClock::obtain())
}
}
+
+// Implementation of gst::ChildProxy virtual methods.
+//
+// This allows accessing the pads and their properties from e.g. gst-launch.
+impl ChildProxyImpl for Transcriber {
+ fn children_count(&self) -> u32 {
+ let object = self.obj();
+ object.num_pads() as u32
+ }
+
+ fn child_by_name(&self, name: &str) -> Option<glib::Object> {
+ let object = self.obj();
+ object
+ .pads()
+ .into_iter()
+ .find(|p| p.name() == name)
+ .map(|p| p.upcast())
+ }
+
+ fn child_by_index(&self, index: u32) -> Option<glib::Object> {
+ let object = self.obj();
+ object
+ .pads()
+ .into_iter()
+ .nth(index as usize)
+ .map(|p| p.upcast())
+ }
+}
+struct TranslationPadTask {
+ pad: glib::subclass::ObjectImplRef<TranslationSrcPad>,
+ elem: super::Transcriber,
+ transcript_event_rx: broadcast::Receiver<TranscriptEvent>,
+ needs_translate: bool,
+ translation_queue: TranslationQueue,
+ translation_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
+ to_translation_tx: Option<mpsc::Sender<TranscriptItem>>,
+ from_translation_rx: Option<mpsc::Receiver<TranslatedItem>>,
+ translate_latency: gst::ClockTime,
+ transcript_lookahead: gst::ClockTime,
+ send_events: bool,
+ translated_items: VecDeque<TranslatedItem>,
+ our_latency: gst::ClockTime,
+ seqnum: gst::Seqnum,
+ send_eos: bool,
+ pending_translations: usize,
+ start_time: Option<gst::ClockTime>,
+}
+
+impl TranslationPadTask {
+ fn try_new(
+ pad: &TranslationSrcPad,
+ elem: super::Transcriber,
+ transcript_event_rx: broadcast::Receiver<TranscriptEvent>,
+ ) -> Result<TranslationPadTask, gst::ErrorMessage> {
+ let mut this = TranslationPadTask {
+ pad: pad.ref_counted(),
+ elem,
+ transcript_event_rx,
+ needs_translate: false,
+ translation_queue: TranslationQueue::default(),
+ translation_loop_handle: None,
+ to_translation_tx: None,
+ from_translation_rx: None,
+ translate_latency: DEFAULT_TRANSLATE_LATENCY,
+ transcript_lookahead: DEFAULT_TRANSCRIPT_LOOKAHEAD,
+ send_events: true,
+ translated_items: VecDeque::new(),
+ our_latency: DEFAULT_TRANSCRIBE_LATENCY,
+ seqnum: gst::Seqnum::next(),
+ send_eos: false,
+ pending_translations: 0,
+ start_time: None,
+ };
+
+ let _enter_guard = RUNTIME.enter();
+ futures::executor::block_on(this.init_translate())?;
+
+ Ok(this)
+ }
+}
+
+impl Drop for TranslationPadTask {
+ fn drop(&mut self) {
+ if let Some(translation_loop_handle) = self.translation_loop_handle.take() {
+ translation_loop_handle.abort();
+ }
+ }
+}
+
+impl TranslationPadTask {
+ async fn run_iter(&mut self) -> Result<(), gst::ErrorMessage> {
+ self.ensure_init_events()?;
+
+ if self.needs_translate {
+ self.translate_iter().await?;
+ } else {
+ self.passthrough_iter().await?;
+ }
+
+ if !self.dequeue().await {
+ gst::info!(CAT, imp: self.pad, "Failed to dequeue buffer, pausing");
+ let _ = self.pad.obj().pause_task();
+ }
+
+ Ok(())
+ }
+
+ async fn passthrough_iter(&mut self) -> Result<(), gst::ErrorMessage> {
+ // This is to make sure we send items on a timely basis or at least Gap events.
+ let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
+ futures::pin_mut!(timeout);
+
+ let transcript_event_rx = self.transcript_event_rx.recv().fuse();
+ futures::pin_mut!(transcript_event_rx);
+
+ // `timeout` takes precedence over `transcript_events` reception
+ // because we may need to `dequeue` `items` or push a `Gap` event
+ // before current latency budget is exhausted.
+ futures::select_biased! {
+ _ = timeout => (),
+ items_res = transcript_event_rx => {
+ use TranscriptEvent::*;
+ use broadcast::error::RecvError;
+ match items_res {
+ Ok(Items(transcript_items)) => {
+ for transcript_item in transcript_items.iter() {
+ self.translated_items.push_back(transcript_item.into());
+ }
+ }
+ Ok(Eos) => {
+ gst::debug!(CAT, imp: self.pad, "Got eos");
+ self.send_eos = true;
+ }
+ Err(RecvError::Lagged(nb_msg)) => {
+ gst::warning!(CAT, imp: self.pad, "Missed {nb_msg} transcript sets");
+ }
+ Err(RecvError::Closed) => {
+ gst::debug!(CAT, imp: self.pad, "Transcript chan terminated: setting eos");
+ self.send_eos = true;
+ }
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ async fn translate_iter(&mut self) -> Result<(), gst::ErrorMessage> {
+ if self
+ .translation_loop_handle
+ .as_ref()
+ .map_or(true, task::JoinHandle::is_finished)
+ {
+ const ERR: &str = "Translation loop is not running";
+ gst::error!(CAT, imp: self.pad, "{ERR}");
+ return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
+ }
+
+ let transcript_items = {
+ // This is to make sure we send items on a timely basis or at least Gap events.
+ let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
+ futures::pin_mut!(timeout);
+
+ let from_translation_rx = self
+ .from_translation_rx
+ .as_mut()
+ .expect("from_translation chan must be available in translation mode");
+
+ let transcript_event_rx = self.transcript_event_rx.recv().fuse();
+ futures::pin_mut!(transcript_event_rx);
+
+ // `timeout` takes precedence over `transcript_events` reception
+ // because we may need to `dequeue` `items` or push a `Gap` event
+ // before current latency budget is exhausted.
+ futures::select_biased! {
+ _ = timeout => return Ok(()),
+ translated_item = from_translation_rx.next() => {
+ let Some(translated_item) = translated_item else {
+ const ERR: &str = "translation chan terminated";
+ gst::debug!(CAT, imp: self.pad, "{ERR}");
+ return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
+ };
+
+ self.translated_items.push_back(translated_item);
+ self.pending_translations = self.pending_translations.saturating_sub(1);
+
+ return Ok(());
+ }
+ items_res = transcript_event_rx => {
+ use TranscriptEvent::*;
+ use broadcast::error::RecvError;
+ match items_res {
+ Ok(Items(transcript_items)) => transcript_items,
+ Ok(Eos) => {
+ gst::debug!(CAT, imp: self.pad, "Got eos");
+ self.send_eos = true;
+ return Ok(());
+ }
+ Err(RecvError::Lagged(nb_msg)) => {
+ gst::warning!(CAT, imp: self.pad, "Missed {nb_msg} transcript sets");
+ return Ok(());
+ }
+ Err(RecvError::Closed) => {
+ gst::debug!(CAT, imp: self.pad, "Transcript chan terminated: setting eos");
+ self.send_eos = true;
+ return Ok(());
+ }
+ }
+ }
+ }
+ };
+
+ for item in transcript_items.iter() {
+ if let Some(ready_item) = self.translation_queue.push(item) {
+ self.send_for_translation(ready_item).await?;
+ }
+ }
+
+ Ok(())
+ }
+
+ async fn dequeue(&mut self) -> bool {
+ let (now, start_time, mut last_position, mut discont_pending);
+ {
+ let mut pad_state = self.pad.state.lock().unwrap();
+
+ let Some(cur_rt) = self.elem.current_running_time() else {
+ // Wait for the clock to be available
+ return true;
+ };
+ now = cur_rt;
+
+ if self.start_time.is_none() {
+ self.start_time = Some(now);
+ pad_state.out_segment.set_position(now);
+ }
+
+ start_time = self.start_time.unwrap();
+ last_position = pad_state.out_segment.position().unwrap();
+ discont_pending = pad_state.discont_pending;
+ }
+
+ if self.needs_translate && !self.translation_queue.is_empty() {
+ // Maximum delay for an item to be pushed to stream on time
+ // Margin:
+ // - 1 * GRANULARITY: the time it will take before we can check this again,
+ // without running late, in the case of a timeout.
+ // - 2 * GRANULARITY: extra margin to account for additional overheads.
+ // FIXME explaing which ones.
+ let max_delay = self.our_latency.saturating_sub(3 * GRANULARITY);
+
+ // Estimated time of arrival for an item sent to translation now.
+ // (in transcript item ts base)
+ let translation_eta = now + self.translate_latency - start_time;
+
+ let deadline = translation_eta.saturating_sub(max_delay);
+
+ if let Some(ready_item) = self
+ .translation_queue
+ .dequeue(deadline, self.transcript_lookahead)
+ {
+ gst::debug!(
+ CAT,
+ imp: self.pad,
+ "Forcing transcript at pts {} with duration {} to translation",
+ ready_item.pts,
+ ready_item.duration,
+ );
+
+ if self.send_for_translation(ready_item).await.is_err() {
+ return false;
+ }
+ }
+ }
+
+ /* First, check our pending buffers */
+ while let Some(item) = self.translated_items.front() {
+ // Note: items pts start from 0 + lateness
+ gst::trace!(
+ CAT,
+ imp: self.pad,
+ "Checking now {now} if item is ready for dequeuing, PTS {}, threshold {} vs {}",
+ item.pts,
+ item.pts + self.our_latency.saturating_sub(3 * GRANULARITY),
+ now - start_time
+ );
+
+ // Margin:
+ // - 1 * GRANULARITY: the time it will take before we can check this again,
+ // without running late, in the case of a timeout.
+ // - 2 * GRANULARITY: extra margin to account for additional overheads.
+ // FIXME explaing which ones.
+ if item.pts + self.our_latency.saturating_sub(3 * GRANULARITY) < now - start_time {
+ /* Safe unwrap, we know we have an item */
+ let TranslatedItem {
+ pts: item_pts,
+ mut duration,
+ content,
+ } = self.translated_items.pop_front().unwrap();
+
+ let mut pts = start_time + item_pts;
+
+ let mut buf = gst::Buffer::from_mut_slice(content.into_bytes());
+ {
+ let buf = buf.get_mut().unwrap();
+
+ if discont_pending {
+ buf.set_flags(gst::BufferFlags::DISCONT);
+ discont_pending = false;
+ }
+
+ buf.set_pts(pts);
+ buf.set_duration(duration);
+ }
+
+ use std::cmp::Ordering::*;
+ match pts.cmp(&last_position) {
+ Greater => {
+ // The buffer we are about to push starts after the end of
+ // last item previously pushed to the stream.
+ let gap_event = gst::event::Gap::builder(last_position)
+ .duration(pts - last_position)
+ .seqnum(self.seqnum)
+ .build();
+ gst::log!(CAT, imp: self.pad, "Pushing gap: {last_position} -> {pts}");
+ if !self.pad.obj().push_event(gap_event) {
+ return false;
+ }
+ }
+ Less => {
+ // The buffer we are about to push was expected to start
+ // before the end of last item previously pushed to the stream.
+ // => update it to fit in stream.
+ let delta = last_position - pts;
+
+ gst::warning!(
+ CAT,
+ imp: self.pad,
+ "Updating item PTS ({pts} < {last_position}), consider increasing latency",
+ );
+
+ pts = last_position;
+ // FIXME if the resulting duration is zero, we might as well not push it.
+ duration = duration.saturating_sub(delta);
+
+ {
+ let buf_mut = buf.get_mut().unwrap();
+
+ buf_mut.set_pts(pts);
+ buf_mut.set_duration(duration);
+ }
+ }
+ _ => (),
+ }
+
+ last_position = pts + duration;
+
+ gst::debug!(CAT, imp: self.pad, "Pushing buffer: {pts} -> {}", pts + duration);
+
+ if self.pad.obj().push(buf).is_err() {
+ return false;
+ }
+ } else {
+ // Current and subsequent items are not ready to be pushed
+ break;
+ }
+ }
+
+ if self.send_eos
+ && self.pending_translations == 0
+ && self.translated_items.is_empty()
+ && self.translation_queue.is_empty()
+ {
+ /* We're EOS, we can pause and exit early */
+ let _ = self.pad.obj().pause_task();
+
+ gst::info!(CAT, imp: self.pad, "Sending eos");
+ return self
+ .pad
+ .obj()
+ .push_event(gst::event::Eos::builder().seqnum(self.seqnum).build());
+ }
+
+ /* next, push a gap if we're lagging behind the target position */
+ gst::trace!(
+ CAT,
+ imp: self.pad,
+ "Checking now: {now} if we need to push a gap, last_position: {last_position}, threshold: {}",
+ last_position + self.our_latency.saturating_sub(GRANULARITY)
+ );
+
+ if now > last_position + self.our_latency.saturating_sub(GRANULARITY) {
+ // We are running out of latency budget since last time we pushed downstream,
+ // so push a Gap long enough to keep continuity before we dequeue again:
+ // worse case scenario, this is GRANULARITY ms from now.
+ let duration = now - last_position - self.our_latency.saturating_sub(GRANULARITY);
+
+ let gap_event = gst::event::Gap::builder(last_position)
+ .duration(duration)
+ .seqnum(self.seqnum)
+ .build();
+
+ gst::log!(
+ CAT,
+ imp: self.pad,
+ "Pushing gap: {last_position} -> {}",
+ last_position + duration
+ );
+
+ last_position += duration;
+
+ if !self.pad.obj().push_event(gap_event) {
+ return false;
+ }
+ }
+
+ let mut pad_state = self.pad.state.lock().unwrap();
+ pad_state.out_segment.set_position(last_position);
+ pad_state.discont_pending = discont_pending;
+
+ true
+ }
+
+ async fn send_for_translation(
+ &mut self,
+ transcript_item: TranscriptItem,
+ ) -> Result<(), gst::ErrorMessage> {
+ let res = self
+ .to_translation_tx
+ .as_mut()
+ .expect("to_translation chan must be available in translation mode")
+ .send(transcript_item)
+ .await;
+
+ if res.is_err() {
+ const ERR: &str = "to_translation chan terminated";
+ gst::debug!(CAT, imp: self.pad, "{ERR}");
+ return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
+ }
+
+ self.pending_translations += 1;
+
+ Ok(())
+ }
+
+ fn ensure_init_events(&mut self) -> Result<(), gst::ErrorMessage> {
+ if !self.send_events {
+ return Ok(());
+ }
+
+ let mut events = vec![];
+
+ {
+ let elem_imp = self.elem.imp();
+ let elem_state = elem_imp.state.lock().unwrap();
+
+ let mut pad_state = self.pad.state.lock().unwrap();
+
+ self.seqnum = elem_state.seqnum;
+ pad_state.out_segment = Default::default();
+
+ events.push(
+ gst::event::StreamStart::builder("transcription")
+ .seqnum(self.seqnum)
+ .build(),
+ );
+
+ let caps = gst::Caps::builder("text/x-raw")
+ .field("format", "utf8")
+ .build();
+ events.push(gst::event::Caps::builder(&caps).seqnum(self.seqnum).build());
+
+ events.push(
+ gst::event::Segment::builder(&pad_state.out_segment)
+ .seqnum(self.seqnum)
+ .build(),
+ );
+ }
+
+ for event in events.drain(..) {
+ gst::info!(CAT, imp: self.pad, "Sending {event:?}");
+ if !self.pad.obj().push_event(event) {
+ const ERR: &str = "Failed to send initial";
+ gst::error!(CAT, imp: self.pad, "{ERR}");
+ return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
+ }
+ }
+
+ self.send_events = false;
+
+ Ok(())
+ }
+}
+
+impl TranslationPadTask {
+ async fn init_translate(&mut self) -> Result<(), gst::ErrorMessage> {
+ let mut translation_loop = None;
+
+ {
+ let elem_imp = self.elem.imp();
+ let elem_settings = elem_imp.settings.lock().unwrap();
+
+ let pad_settings = self.pad.settings.lock().unwrap();
+
+ self.our_latency = TranslationSrcPad::our_latency(&elem_settings, &pad_settings);
+ if self.our_latency + elem_settings.lateness <= 2 * GRANULARITY {
+ let err = format!(
+ "total latency + lateness must be greater than {}",
+ 2 * GRANULARITY
+ );
+ gst::error!(CAT, imp: self.pad, "{err}");
+ return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"]));
+ }
+
+ self.translate_latency = elem_settings.translate_latency;
+ self.transcript_lookahead = elem_settings.transcript_lookahead;
+
+ self.needs_translate = TranslationSrcPad::needs_translation(
+ &elem_settings.language_code,
+ pad_settings.language_code.as_deref(),
+ );
+
+ if self.needs_translate {
+ let (to_translation_tx, to_translation_rx) = mpsc::channel(64);
+ let (from_translation_tx, from_translation_rx) = mpsc::channel(64);
+
+ translation_loop = Some(TranslationLoop::new(
+ elem_imp,
+ &self.pad,
+ &elem_settings.language_code,
+ pad_settings.language_code.as_deref().unwrap(),
+ to_translation_rx,
+ from_translation_tx,
+ ));
+
+ self.to_translation_tx = Some(to_translation_tx);
+ self.from_translation_rx = Some(from_translation_rx);
+ }
+ }
+
+ if let Some(translation_loop) = translation_loop {
+ translation_loop.check_language().await?;
+ self.translation_loop_handle = Some(RUNTIME.spawn(translation_loop.run()));
+ }
+
+ Ok(())
+ }
+}
+
+#[derive(Debug)]
+struct TranslationPadState {
+ discont_pending: bool,
+ out_segment: gst::FormattedSegment<gst::ClockTime>,
+ task_abort_handle: Option<AbortHandle>,
+}
+
+impl Default for TranslationPadState {
+ fn default() -> TranslationPadState {
+ TranslationPadState {
+ discont_pending: true,
+ out_segment: Default::default(),
+ task_abort_handle: None,
+ }
+ }
+}
+
+#[derive(Debug, Default, Clone)]
+struct TranslationPadSettings {
+ language_code: Option<String>,
+}
+
+#[derive(Debug, Default)]
+pub struct TranslationSrcPad {
+ state: Mutex<TranslationPadState>,
+ settings: Mutex<TranslationPadSettings>,
+}
+
+impl TranslationSrcPad {
+ fn start_task(&self) -> Result<(), gst::LoggableError> {
+ gst::debug!(CAT, imp: self, "Starting task");
+
+ let elem = self.parent();
+ let transcript_event_rx = elem.imp().transcript_event_tx.subscribe();
+ let mut pad_task = TranslationPadTask::try_new(self, elem, transcript_event_rx)
+ .map_err(|err| gst::loggable_error!(CAT, format!("Failed to start pad task {err}")))?;
+
+ let imp = self.ref_counted();
+ let res = self.obj().start_task(move || {
+ let (abortable_task_iter, abort_handle) = future::abortable(pad_task.run_iter());
+ imp.state.lock().unwrap().task_abort_handle = Some(abort_handle);
+
+ let _enter = RUNTIME.enter();
+ match futures::executor::block_on(abortable_task_iter) {
+ Ok(Ok(())) => (),
+ Ok(Err(err)) => {
+ // Don't bring down the whole element if this Pad fails
+ // FIXME is there a way to mark the Pad in error though?
+ gst::info!(CAT, imp: imp, "Pausing task due to: {err}");
+ let _ = imp.obj().pause_task();
+ }
+ Err(_) => gst::debug!(CAT, imp: imp, "task iter aborted"),
+ }
+ });
+
+ if res.is_err() {
+ return Err(gst::loggable_error!(CAT, "Failed to start pad task"));
+ }
+
+ gst::debug!(CAT, imp: self, "Task started");
+
+ Ok(())
+ }
+
+ fn stop_task(&self) {
+ gst::debug!(CAT, imp: self, "Stopping task");
+
+ // See also the note in `start_task()`:
+ // 1. Mark the task as stopped so no further iteration is executed.
+ let _ = self.obj().stop_task();
+
+ // 2. Abort the task iteration if the Future is pending.
+ if let Some(task_abort_handle) = self.state.lock().unwrap().task_abort_handle.take() {
+ task_abort_handle.abort();
+ }
+
+ gst::debug!(CAT, imp: self, "Task stopped");
+ }
+
+ fn set_discont(&self) {
+ self.state.lock().unwrap().discont_pending = true;
+ }
+
+ #[inline]
+ fn needs_translation(input_lang: &str, output_lang: Option<&str>) -> bool {
+ output_lang.map_or(false, |other| {
+ !input_lang.eq_ignore_ascii_case(other.as_ref())
+ })
+ }
+
+ #[inline]
+ fn our_latency(
+ elem_settings: &Settings,
+ pad_settings: &TranslationPadSettings,
+ ) -> gst::ClockTime {
+ if Self::needs_translation(
+ &elem_settings.language_code,
+ pad_settings.language_code.as_deref(),
+ ) {
+ elem_settings.transcribe_latency
+ + elem_settings.transcript_lookahead
+ + elem_settings.translate_latency
+ } else {
+ elem_settings.transcribe_latency
+ }
+ }
+
+ #[track_caller]
+ fn parent(&self) -> super::Transcriber {
+ self.obj()
+ .parent()
+ .map(|elem_obj| {
+ elem_obj
+ .downcast::<super::Transcriber>()
+ .expect("Wrong Element type")
+ })
+ .expect("Pad should have a parent at this stage")
+ }
+}
+
+impl TranslationSrcPad {
+ #[track_caller]
+ pub fn activatemode(
+ _elem: &Transcriber,
+ pad: &super::TranslationSrcPad,
+ _mode: gst::PadMode,
+ active: bool,
+ ) -> Result<(), gst::LoggableError> {
+ if active {
+ pad.imp().start_task()?;
+ } else {
+ pad.imp().stop_task();
+ }
+
+ Ok(())
+ }
+
+ pub fn src_query(
+ elem: &Transcriber,
+ pad: &super::TranslationSrcPad,
+ query: &mut gst::QueryRef,
+ ) -> bool {
+ gst::log!(CAT, obj: pad, "Handling query {query:?}");
+
+ use gst::QueryViewMut::*;
+ match query.view_mut() {
+ Latency(q) => {
+ let mut peer_query = gst::query::Latency::new();
+
+ let ret = elem.sinkpad.peer_query(&mut peer_query);
+
+ if ret {
+ let (_, min, _) = peer_query.result();
+
+ let our_latency = {
+ let elem_settings = elem.settings.lock().unwrap();
+ let pad_settings = pad.imp().settings.lock().unwrap();
+
+ Self::our_latency(&elem_settings, &pad_settings)
+ };
+
+ gst::info!(CAT, obj: pad, "Our latency {our_latency}");
+ q.set(true, our_latency + min, gst::ClockTime::NONE);
+ }
+ ret
+ }
+ Position(q) => {
+ if q.format() == gst::Format::Time {
+ let stream_time = {
+ let state = pad.imp().state.lock().unwrap();
+ state
+ .out_segment
+ .to_stream_time(state.out_segment.position())
+ };
+
+ let Some(stream_time) = stream_time else { return false };
+ q.set(stream_time);
+
+ true
+ } else {
+ false
+ }
+ }
+ _ => gst::Pad::query_default(pad, Some(pad), query),
+ }
+ }
+}
+
+#[glib::object_subclass]
+impl ObjectSubclass for TranslationSrcPad {
+ const NAME: &'static str = "GstTranslationSrcPad";
+ type Type = super::TranslationSrcPad;
+ type ParentType = gst::Pad;
+
+ fn new() -> Self {
+ Default::default()
+ }
+}
+
+impl ObjectImpl for TranslationSrcPad {
+ fn properties() -> &'static [glib::ParamSpec] {
+ static PROPERTIES: Lazy<Vec<glib::ParamSpec>> = Lazy::new(|| {
+ vec![glib::ParamSpecString::builder(OUTPUT_LANG_CODE_PROPERTY)
+ .nick("Language Code")
+ .blurb("The Language the Stream must be translated to")
+ .default_value(DEFAULT_OUTPUT_LANG_CODE)
+ .mutable_ready()
+ .build()]
+ });
+
+ PROPERTIES.as_ref()
+ }
+
+ fn set_property(&self, _id: usize, value: &glib::Value, pspec: &glib::ParamSpec) {
+ match pspec.name() {
+ OUTPUT_LANG_CODE_PROPERTY => {
+ self.settings.lock().unwrap().language_code = value.get().unwrap()
+ }
+ _ => unimplemented!(),
+ }
+ }
+
+ fn property(&self, _id: usize, pspec: &glib::ParamSpec) -> glib::Value {
+ match pspec.name() {
+ OUTPUT_LANG_CODE_PROPERTY => self.settings.lock().unwrap().language_code.to_value(),
+ _ => unimplemented!(),
+ }
+ }
+}
+
+impl GstObjectImpl for TranslationSrcPad {}
+
+impl PadImpl for TranslationSrcPad {}
diff --git a/net/aws/src/transcriber/mod.rs b/net/aws/src/transcriber/mod.rs
index 69ac6059..eb2a28f7 100644
--- a/net/aws/src/transcriber/mod.rs
+++ b/net/aws/src/transcriber/mod.rs
@@ -10,6 +10,18 @@ use gst::glib;
use gst::prelude::*;
mod imp;
+mod transcribe;
+mod translate;
+
+use once_cell::sync::Lazy;
+
+static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
+ gst::DebugCategory::new(
+ "awstranscribe",
+ gst::DebugColorFlags::empty(),
+ Some("AWS Transcribe element"),
+ )
+});
use aws_sdk_transcribestreaming::model::{PartialResultsStability, VocabularyFilterMethod};
@@ -68,7 +80,11 @@ impl From<AwsTranscriberVocabularyFilterMethod> for VocabularyFilterMethod {
}
glib::wrapper! {
- pub struct Transcriber(ObjectSubclass<imp::Transcriber>) @extends gst::Element, gst::Object;
+ pub struct Transcriber(ObjectSubclass<imp::Transcriber>) @extends gst::Element, gst::Object, @implements gst::ChildProxy;
+}
+
+glib::wrapper! {
+ pub struct TranslationSrcPad(ObjectSubclass<imp::TranslationSrcPad>) @extends gst::Pad, gst::Object;
}
pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> {
@@ -78,6 +94,7 @@ pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> {
.mark_as_plugin_api(gst::PluginAPIFlags::empty());
AwsTranscriberVocabularyFilterMethod::static_type()
.mark_as_plugin_api(gst::PluginAPIFlags::empty());
+ TranslationSrcPad::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty());
}
gst::Element::register(
Some(plugin),
diff --git a/net/aws/src/transcriber/transcribe.rs b/net/aws/src/transcriber/transcribe.rs
new file mode 100644
index 00000000..7b683f3b
--- /dev/null
+++ b/net/aws/src/transcriber/transcribe.rs
@@ -0,0 +1,277 @@
+// Copyright (C) 2020 Mathieu Duponchelle <mathieu@centricular.com>
+// Copyright (C) 2023 François Laignel <francois@centricular.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0.
+// If a copy of the MPL was not distributed with this file, You can obtain one at
+// <https://mozilla.org/MPL/2.0/>.
+//
+// SPDX-License-Identifier: MPL-2.0
+
+use gst::subclass::prelude::*;
+use gst::{glib, prelude::*};
+
+use aws_sdk_transcribestreaming as aws_transcribe;
+use aws_sdk_transcribestreaming::model;
+
+use futures::channel::mpsc;
+use futures::prelude::*;
+use tokio::sync::broadcast;
+
+use std::sync::Arc;
+
+use super::imp::{Settings, Transcriber};
+use super::CAT;
+
+#[derive(Debug)]
+pub struct TranscriptionSettings {
+ lang_code: model::LanguageCode,
+ sample_rate: i32,
+ vocabulary: Option<String>,
+ vocabulary_filter: Option<String>,
+ vocabulary_filter_method: model::VocabularyFilterMethod,
+ session_id: Option<String>,
+ results_stability: model::PartialResultsStability,
+}
+
+impl TranscriptionSettings {
+ pub(super) fn from(settings: &Settings, sample_rate: i32) -> Self {
+ TranscriptionSettings {
+ lang_code: settings.language_code.as_str().into(),
+ sample_rate,
+ vocabulary: settings.vocabulary.clone(),
+ vocabulary_filter: settings.vocabulary_filter.clone(),
+ vocabulary_filter_method: settings.vocabulary_filter_method.into(),
+ session_id: settings.session_id.clone(),
+ results_stability: settings.results_stability.into(),
+ }
+ }
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct TranscriptItem {
+ pub pts: gst::ClockTime,
+ pub duration: gst::ClockTime,
+ pub content: String,
+ pub is_punctuation: bool,
+}
+
+impl TranscriptItem {
+ pub fn from(item: model::Item, lateness: gst::ClockTime) -> Option<TranscriptItem> {
+ let content = item.content?;
+
+ let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness;
+ let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness;
+
+ Some(TranscriptItem {
+ pts: start_time,
+ duration: end_time - start_time,
+ content,
+ is_punctuation: matches!(item.r#type, Some(model::ItemType::Punctuation)),
+ })
+ }
+
+ #[inline]
+ pub fn push(&mut self, item: &TranscriptItem) {
+ self.duration += item.duration;
+
+ self.is_punctuation &= item.is_punctuation;
+ if !item.is_punctuation {
+ self.content.push(' ');
+ }
+
+ self.content.push_str(&item.content);
+ }
+}
+
+#[derive(Clone)]
+pub enum TranscriptEvent {
+ Items(Arc<Vec<TranscriptItem>>),
+ Eos,
+}
+
+impl From<Vec<TranscriptItem>> for TranscriptEvent {
+ fn from(transcript_items: Vec<TranscriptItem>) -> Self {
+ TranscriptEvent::Items(transcript_items.into())
+ }
+}
+
+pub struct TranscriberLoop {
+ imp: glib::subclass::ObjectImplRef<Transcriber>,
+ client: aws_transcribe::Client,
+ settings: Option<TranscriptionSettings>,
+ lateness: gst::ClockTime,
+ buffer_rx: Option<mpsc::Receiver<gst::Buffer>>,
+ transcript_items_tx: broadcast::Sender<TranscriptEvent>,
+ partial_index: usize,
+}
+
+impl TranscriberLoop {
+ pub fn new(
+ imp: &Transcriber,
+ settings: TranscriptionSettings,
+ lateness: gst::ClockTime,
+ buffer_rx: mpsc::Receiver<gst::Buffer>,
+ transcript_items_tx: broadcast::Sender<TranscriptEvent>,
+ ) -> Self {
+ let aws_config = imp.aws_config.lock().unwrap();
+ let aws_config = aws_config
+ .as_ref()
+ .expect("aws_config must be initialized at this stage");
+
+ TranscriberLoop {
+ imp: imp.ref_counted(),
+ client: aws_transcribe::Client::new(aws_config),
+ settings: Some(settings),
+ lateness,
+ buffer_rx: Some(buffer_rx),
+ transcript_items_tx,
+ partial_index: 0,
+ }
+ }
+
+ pub async fn run(mut self) -> Result<(), gst::ErrorMessage> {
+ // Stream the incoming buffers chunked
+ let chunk_stream = self.buffer_rx.take().unwrap().flat_map(move |buffer: gst::Buffer| {
+ async_stream::stream! {
+ let data = buffer.map_readable().unwrap();
+ use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
+ for chunk in data.chunks(8192) {
+ yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build()));
+ }
+ }
+ });
+
+ let settings = self.settings.take().unwrap();
+ let mut transcribe_builder = self
+ .client
+ .start_stream_transcription()
+ .language_code(settings.lang_code)
+ .media_sample_rate_hertz(settings.sample_rate)
+ .media_encoding(model::MediaEncoding::Pcm)
+ .enable_partial_results_stabilization(true)
+ .partial_results_stability(settings.results_stability)
+ .set_vocabulary_name(settings.vocabulary)
+ .set_session_id(settings.session_id);
+
+ if let Some(vocabulary_filter) = settings.vocabulary_filter {
+ transcribe_builder = transcribe_builder
+ .vocabulary_filter_name(vocabulary_filter)
+ .vocabulary_filter_method(settings.vocabulary_filter_method);
+ }
+
+ let mut output = transcribe_builder
+ .audio_stream(chunk_stream.into())
+ .send()
+ .await
+ .map_err(|err| {
+ let err = format!("Transcribe ws init error: {err}");
+ gst::error!(CAT, imp: self.imp, "{err}");
+ gst::error_msg!(gst::LibraryError::Init, ["{err}"])
+ })?;
+
+ while let Some(event) = output
+ .transcript_result_stream
+ .recv()
+ .await
+ .map_err(|err| {
+ let err = format!("Transcribe ws stream error: {err}");
+ gst::error!(CAT, imp: self.imp, "{err}");
+ gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
+ })?
+ {
+ if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
+ let mut ready_items = None;
+
+ if let Some(result) = transcript_evt
+ .transcript
+ .and_then(|transcript| transcript.results)
+ .and_then(|mut results| results.drain(..).next())
+ {
+ gst::trace!(CAT, imp: self.imp, "Received: {result:?}");
+
+ if let Some(alternative) = result
+ .alternatives
+ .and_then(|mut alternatives| alternatives.drain(..).next())
+ {
+ ready_items = alternative.items.and_then(|items| {
+ self.get_ready_transcript_items(items, result.is_partial)
+ });
+ }
+ }
+
+ if let Some(ready_items) = ready_items {
+ if self.transcript_items_tx.send(ready_items.into()).is_err() {
+ gst::debug!(CAT, imp: self.imp, "No transcript items receivers");
+ break;
+ }
+ }
+ } else {
+ gst::warning!(
+ CAT,
+ imp: self.imp,
+ "Transcribe ws returned unknown event: consider upgrading the SDK"
+ )
+ }
+ }
+
+ gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS");
+ let _ = self.transcript_items_tx.send(TranscriptEvent::Eos);
+
+ gst::debug!(CAT, imp: self.imp, "Exiting transcriber loop");
+
+ Ok(())
+ }
+
+ /// Builds a list from the provided stable items.
+ fn get_ready_transcript_items(
+ &mut self,
+ mut items: Vec<model::Item>,
+ partial: bool,
+ ) -> Option<Vec<TranscriptItem>> {
+ if items.len() <= self.partial_index {
+ gst::error!(
+ CAT,
+ imp: self.imp,
+ "sanity check failed, alternative length {} < partial_index {}",
+ items.len(),
+ self.partial_index
+ );
+
+ if !partial {
+ self.partial_index = 0;
+ }
+
+ return None;
+ }
+
+ let mut output = vec![];
+
+ for item in items.drain(self.partial_index..) {
+ if !item.stable().unwrap_or(false) {
+ break;
+ }
+
+ let Some(item) = TranscriptItem::from(item, self.lateness) else { continue };
+ gst::debug!(
+ CAT,
+ imp: self.imp,
+ "Item is ready for queuing: {}, PTS {}",
+ item.content,
+ item.pts,
+ );
+
+ self.partial_index += 1;
+ output.push(item);
+ }
+
+ if !partial {
+ self.partial_index = 0;
+ }
+
+ if output.is_empty() {
+ return None;
+ }
+
+ Some(output)
+ }
+}
diff --git a/net/aws/src/transcriber/translate.rs b/net/aws/src/transcriber/translate.rs
new file mode 100644
index 00000000..b689bd63
--- /dev/null
+++ b/net/aws/src/transcriber/translate.rs
@@ -0,0 +1,215 @@
+// Copyright (C) 2023 François Laignel <francois@centricular.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0.
+// If a copy of the MPL was not distributed with this file, You can obtain one at
+// <https://mozilla.org/MPL/2.0/>.
+//
+// SPDX-License-Identifier: MPL-2.0
+
+use gst::glib;
+use gst::subclass::prelude::*;
+
+use aws_sdk_translate as aws_translate;
+
+use futures::channel::mpsc;
+use futures::prelude::*;
+
+use std::collections::VecDeque;
+
+use super::imp::TranslationSrcPad;
+use super::transcribe::TranscriptItem;
+use super::CAT;
+
+pub struct TranslatedItem {
+ pub pts: gst::ClockTime,
+ pub duration: gst::ClockTime,
+ pub content: String,
+}
+
+impl From<&TranscriptItem> for TranslatedItem {
+ fn from(transcript_item: &TranscriptItem) -> Self {
+ TranslatedItem {
+ pts: transcript_item.pts,
+ duration: transcript_item.duration,
+ content: transcript_item.content.clone(),
+ }
+ }
+}
+
+#[derive(Default)]
+pub struct TranslationQueue {
+ items: VecDeque<TranscriptItem>,
+}
+
+impl TranslationQueue {
+ pub fn is_empty(&self) -> bool {
+ self.items.is_empty()
+ }
+
+ /// Pushes the provided item.
+ ///
+ /// Returns `Some(..)` if items are ready for translation.
+ pub fn push(&mut self, transcript_item: &TranscriptItem) -> Option<TranscriptItem> {
+ // Keep track of the item individually so we can schedule translation precisely.
+ self.items.push_back(transcript_item.clone());
+
+ if transcript_item.is_punctuation {
+ // This makes it a good chunk for translation.
+ // Concatenate as a single item for translation
+
+ let mut items = self.items.drain(..);
+
+ let mut item_acc = items.next()?;
+ for item in items {
+ item_acc.push(&item);
+ }
+
+ item_acc.push(transcript_item);
+
+ return Some(item_acc);
+ }
+
+ // Regular case: no separator detected, don't push transcript items
+ // to translation now. They will be pushed either if a punctuation
+ // is found or of a `dequeue()` is requested.
+
+ None
+ }
+
+ /// Dequeues items from the specified `deadline` up to `lookahead`.
+ ///
+ /// Returns `Some(..)` with the accumulated items matching the criteria.
+ pub fn dequeue(
+ &mut self,
+ deadline: gst::ClockTime,
+ lookahead: gst::ClockTime,
+ ) -> Option<TranscriptItem> {
+ if self.items.front()?.pts < deadline {
+ // First item is too early to be sent to translation now
+ // we can wait for more items to accumulate.
+ return None;
+ }
+
+ // Can't wait any longer to send the first item to translation
+ // Try to get up to lookahead more items to improve translation accuracy
+ let limit = deadline + lookahead;
+
+ let mut item_acc = self.items.pop_front().unwrap();
+ while let Some(item) = self.items.front() {
+ if item.pts > limit {
+ break;
+ }
+
+ let item = self.items.pop_front().unwrap();
+ item_acc.push(&item);
+ }
+
+ Some(item_acc)
+ }
+}
+
+pub struct TranslationLoop {
+ pad: glib::subclass::ObjectImplRef<TranslationSrcPad>,
+ client: aws_translate::Client,
+ input_lang: String,
+ output_lang: String,
+ transcript_rx: mpsc::Receiver<TranscriptItem>,
+ translation_tx: mpsc::Sender<TranslatedItem>,
+}
+
+impl TranslationLoop {
+ pub fn new(
+ imp: &super::imp::Transcriber,
+ pad: &TranslationSrcPad,
+ input_lang: &str,
+ output_lang: &str,
+ transcript_rx: mpsc::Receiver<TranscriptItem>,
+ translation_tx: mpsc::Sender<TranslatedItem>,
+ ) -> Self {
+ let aws_config = imp.aws_config.lock().unwrap();
+ let aws_config = aws_config
+ .as_ref()
+ .expect("aws_config must be initialized at this stage");
+
+ TranslationLoop {
+ pad: pad.ref_counted(),
+ client: aws_sdk_translate::Client::new(aws_config),
+ input_lang: input_lang.to_string(),
+ output_lang: output_lang.to_string(),
+ transcript_rx,
+ translation_tx,
+ }
+ }
+
+ pub async fn check_language(&self) -> Result<(), gst::ErrorMessage> {
+ let language_list = self.client.list_languages().send().await.map_err(|err| {
+ let err = format!("Failed to call list_languages service: {err}");
+ gst::info!(CAT, imp: self.pad, "{err}");
+ gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
+ })?;
+
+ let found_output_lang = language_list
+ .languages()
+ .and_then(|langs| {
+ langs
+ .iter()
+ .find(|lang| lang.language_code() == Some(&self.output_lang))
+ })
+ .is_some();
+
+ if !found_output_lang {
+ let err = format!("Unknown output languages: {}", self.output_lang);
+ gst::info!(CAT, imp: self.pad, "{err}");
+ return Err(gst::error_msg!(gst::LibraryError::Failed, ["{err}"]));
+ }
+
+ Ok(())
+ }
+
+ pub async fn run(mut self) -> Result<(), gst::ErrorMessage> {
+ while let Some(transcript_item) = self.transcript_rx.next().await {
+ let TranscriptItem {
+ pts,
+ duration,
+ content,
+ ..
+ } = transcript_item;
+
+ let translated_text = if content.is_empty() {
+ content
+ } else {
+ self.client
+ .translate_text()
+ .set_source_language_code(Some(self.input_lang.clone()))
+ .set_target_language_code(Some(self.output_lang.clone()))
+ .set_text(Some(content))
+ .send()
+ .await
+ .map_err(|err| {
+ let err = format!("Failed to call translation service: {err}");
+ gst::info!(CAT, imp: self.pad, "{err}");
+ gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
+ })?
+ .translated_text
+ .unwrap_or_default()
+ };
+
+ let translated_item = TranslatedItem {
+ pts,
+ duration,
+ content: translated_text,
+ };
+
+ if self.translation_tx.send(translated_item).await.is_err() {
+ gst::info!(
+ CAT,
+ imp: self.pad,
+ "translation chan terminated, exiting translation loop"
+ );
+ break;
+ }
+ }
+
+ Ok(())
+ }
+}