Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/sdroege/gst-plugin-rs.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/net/aws
diff options
context:
space:
mode:
author: François Laignel <francois@centricular.com> 2023-02-24 23:40:54 +0300
committer: GStreamer Marge Bot <gitlab-merge-bot@gstreamer-foundation.org> 2023-03-01 11:47:58 +0300
commit: 00153754bb8ce7e98c7c4390feb5c775865ca8bd (patch)
tree: b5dc48c567c1bc5d43f0d6990a12c1d07ec22c03 /net/aws
parent: 57f365979c9dea7c2a0061bf3001e7819c5ec9bf (diff)
net/aws: use aws-sdk-transcribestreaming
Switch from manual webservice client impl to `aws-sdk-transcribestreaming`.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1104>
Diffstat (limited to 'net/aws')
-rw-r--r-- net/aws/Cargo.toml 28
-rw-r--r-- net/aws/src/transcriber/imp.rs 1003
-rw-r--r-- net/aws/src/transcriber/mod.rs 25
-rw-r--r-- net/aws/src/transcriber/packet/mod.rs 174
4 files changed, 448 insertions, 782 deletions
diff --git a/net/aws/Cargo.toml b/net/aws/Cargo.toml
index ef6370e9..c329009f 100644
--- a/net/aws/Cargo.toml
+++ b/net/aws/Cargo.toml
@@ -11,36 +11,30 @@ edition = "2021"
rust-version = "1.66"
[dependencies]
-bytes = "1.0"
-futures = "0.3"
-gst = { package = "gstreamer", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" }
-gst-base = { package = "gstreamer-base", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" }
-gst-audio = { package = "gstreamer-audio", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"] }
+async-stream = "0.3.4"
+base32 = "0.4"
aws-config = "0.54.0"
aws-sdk-s3 = "0.24.0"
-aws-sdk-transcribe = "0.24.0"
+aws-sdk-transcribestreaming = "0.24.0"
aws-types = "0.54.0"
aws-credential-types = "0.54.0"
aws-sig-auth = "0.54.0"
aws-smithy-http = { version = "0.54.0", features = [ "rt-tokio" ] }
aws-smithy-types = "0.54.0"
+bytes = "1.0"
+futures = "0.3"
+gio = { git = "https://github.com/gtk-rs/gtk-rs-core.git", package = "gio" }
+gst = { package = "gstreamer", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" }
+gst-base = { package = "gstreamer-base", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" }
+gst-audio = { package = "gstreamer-audio", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"] }
http = "0.2.7"
-chrono = "0.4"
-url = "2"
+once_cell = "1.0"
percent-encoding = "2"
tokio = { version = "1.0", features = [ "full" ] }
-async-tungstenite = { version = "0.20", features = ["tokio", "tokio-runtime", "tokio-native-tls"] }
-nom = "7"
-crc = "3"
-byteorder = "1.3.4"
-once_cell = "1.0"
serde = "1"
serde_derive = "1"
serde_json = "1"
-atomic_refcell = "0.1"
-base32 = "0.4"
-backoff = { version = "0.4", features = [ "futures", "tokio" ] }
-gio = { git = "https://github.com/gtk-rs/gtk-rs-core.git", package = "gio" }
+url = "2"
[dev-dependencies]
chrono = { version = "0.4", features = [ "alloc" ] }
diff --git a/net/aws/src/transcriber/imp.rs b/net/aws/src/transcriber/imp.rs
index 3918c61c..18e45c08 100644
--- a/net/aws/src/transcriber/imp.rs
+++ b/net/aws/src/transcriber/imp.rs
@@ -1,4 +1,5 @@
// Copyright (C) 2020 Mathieu Duponchelle <mathieu@centricular.com>
+// Copyright (C) 2023 François Laignel <francois@centricular.com>
//
// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0.
// If a copy of the MPL was not distributed with this file, You can obtain one at
@@ -9,91 +10,22 @@
use gst::glib;
use gst::prelude::*;
use gst::subclass::prelude::*;
-use gst::{element_imp_error, error_msg, loggable_error};
-use std::default::Default;
+use aws_sdk_transcribestreaming as aws_transcribe;
+use aws_sdk_transcribestreaming::model;
-use aws_config::default_provider::credentials::DefaultCredentialsChain;
-use aws_credential_types::{provider::ProvideCredentials, Credentials};
-use aws_sig_auth::signer::{self, HttpSignatureType, OperationSigningConfig, RequestConfig};
-use aws_smithy_http::body::SdkBody;
-use aws_types::region::{Region, SigningRegion};
-use aws_types::SigningService;
-use std::time::{Duration, SystemTime};
-
-use chrono::prelude::*;
-use http::Uri;
-
-use async_tungstenite::tungstenite::error::Error as WsError;
-use async_tungstenite::{tokio::connect_async, tungstenite::Message};
use futures::channel::mpsc;
-use futures::future::{abortable, AbortHandle};
use futures::prelude::*;
-use tokio::runtime;
+use tokio::{runtime, task};
use std::cmp::Ordering;
use std::collections::VecDeque;
-use std::pin::Pin;
use std::sync::Mutex;
-use atomic_refcell::AtomicRefCell;
-
-use super::packet::*;
-
-use serde_derive::{Deserialize, Serialize};
-
use once_cell::sync::Lazy;
use super::{AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod};
-const DEFAULT_TRANSCRIBER_REGION: &str = "us-east-1";
-
-#[derive(Deserialize, Serialize, Debug)]
-#[serde(rename_all = "PascalCase")]
-struct TranscriptItem {
- content: String,
- end_time: f32,
- start_time: f32,
- #[serde(rename = "Type")]
- type_: String,
- stable: bool,
-}
-
-#[derive(Deserialize, Serialize, Debug)]
-#[serde(rename_all = "PascalCase")]
-struct TranscriptAlternative {
- items: Vec<TranscriptItem>,
- transcript: String,
-}
-
-#[derive(Deserialize, Serialize, Debug)]
-#[serde(rename_all = "PascalCase")]
-struct TranscriptResult {
- alternatives: Vec<TranscriptAlternative>,
- end_time: f32,
- start_time: f32,
- is_partial: bool,
- result_id: String,
-}
-
-#[derive(Deserialize, Debug)]
-#[serde(rename_all = "PascalCase")]
-struct TranscriptTranscript {
- results: Vec<TranscriptResult>,
-}
-
-#[derive(Deserialize, Debug)]
-#[serde(rename_all = "PascalCase")]
-struct Transcript {
- transcript: TranscriptTranscript,
-}
-
-#[derive(Deserialize, Debug)]
-#[serde(rename_all = "PascalCase")]
-struct ExceptionMessage {
- message: String,
-}
-
static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
gst::DebugCategory::new(
"awstranscribe",
@@ -110,8 +42,10 @@ static RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
.unwrap()
});
+const DEFAULT_TRANSCRIBER_REGION: &str = "us-east-1";
const DEFAULT_LATENCY: gst::ClockTime = gst::ClockTime::from_seconds(8);
const DEFAULT_LATENESS: gst::ClockTime = gst::ClockTime::ZERO;
+const DEFAULT_LANGUAGE_CODE: &str = "en-US";
const DEFAULT_STABILITY: AwsTranscriberResultStability = AwsTranscriberResultStability::Low;
const DEFAULT_VOCABULARY_FILTER_METHOD: AwsTranscriberVocabularyFilterMethod =
AwsTranscriberVocabularyFilterMethod::Mask;
@@ -121,7 +55,7 @@ const GRANULARITY: gst::ClockTime = gst::ClockTime::from_mseconds(100);
struct Settings {
latency: gst::ClockTime,
lateness: gst::ClockTime,
- language_code: Option<String>,
+ language_code: String,
vocabulary: Option<String>,
session_id: Option<String>,
results_stability: AwsTranscriberResultStability,
@@ -132,12 +66,12 @@ struct Settings {
vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod,
}
-impl Default for Settings {
+impl std::default::Default for Settings {
fn default() -> Self {
Self {
latency: DEFAULT_LATENCY,
lateness: DEFAULT_LATENESS,
- language_code: Some("en-US".to_string()),
+ language_code: DEFAULT_LANGUAGE_CODE.to_string(),
vocabulary: None,
session_id: None,
results_stability: DEFAULT_STABILITY,
@@ -150,29 +84,55 @@ impl Default for Settings {
}
}
+#[derive(Debug)]
+struct TranscriptionSettings {
+ lang_code: model::LanguageCode,
+ sample_rate: i32,
+ vocabulary: Option<String>,
+ vocabulary_filter: Option<String>,
+ vocabulary_filter_method: model::VocabularyFilterMethod,
+ session_id: Option<String>,
+ results_stability: model::PartialResultsStability,
+}
+
+impl TranscriptionSettings {
+ fn from(settings: &Settings, sample_rate: i32) -> Self {
+ TranscriptionSettings {
+ lang_code: settings.language_code.as_str().into(),
+ sample_rate,
+ vocabulary: settings.vocabulary.clone(),
+ vocabulary_filter: settings.vocabulary_filter.clone(),
+ vocabulary_filter_method: settings.vocabulary_filter_method.into(),
+ session_id: settings.session_id.clone(),
+ results_stability: settings.results_stability.into(),
+ }
+ }
+}
+
struct State {
- connected: bool,
- sender: Option<mpsc::Sender<Message>>,
- recv_abort_handle: Option<AbortHandle>,
- send_abort_handle: Option<AbortHandle>,
+ client: Option<aws_transcribe::Client>,
+ buffer_tx: Option<mpsc::Sender<gst::Buffer>>,
+ transcript_tx: Option<mpsc::Sender<model::TranscriptEvent>>,
+ ws_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
in_segment: gst::FormattedSegment<gst::ClockTime>,
out_segment: gst::FormattedSegment<gst::ClockTime>,
seqnum: gst::Seqnum,
buffers: VecDeque<gst::Buffer>,
send_eos: bool,
+ // FIXME never set to true
discont: bool,
partial_index: usize,
send_events: bool,
start_time: Option<gst::ClockTime>,
}
-impl Default for State {
+impl std::default::Default for State {
fn default() -> Self {
Self {
- connected: false,
- sender: None,
- recv_abort_handle: None,
- send_abort_handle: None,
+ client: None,
+ buffer_tx: None,
+ transcript_tx: None,
+ ws_loop_handle: None,
in_segment: gst::FormattedSegment::new(),
out_segment: gst::FormattedSegment::new(),
seqnum: gst::Seqnum::next(),
@@ -186,36 +146,11 @@ impl Default for State {
}
}
-type WsSink = Pin<Box<dyn Sink<Message, Error = WsError> + Send + Sync>>;
-
pub struct Transcriber {
srcpad: gst::Pad,
sinkpad: gst::Pad,
settings: Mutex<Settings>,
state: Mutex<State>,
- ws_sink: AtomicRefCell<Option<WsSink>>,
-}
-
-fn build_packet(payload: &[u8]) -> Vec<u8> {
- let headers = [
- Header {
- name: ":event-type".into(),
- value: "AudioEvent".into(),
- value_type: 7,
- },
- Header {
- name: ":content-type".into(),
- value: "application/octet-stream".into(),
- value_type: 7,
- },
- Header {
- name: ":message-type".into(),
- value: "event".into(),
- value_type: 7,
- },
- ];
-
- encode_packet(payload, &headers).expect("foobar")
}
impl Transcriber {
@@ -223,12 +158,7 @@ impl Transcriber {
/* First, check our pending buffers */
let mut items = vec![];
- let now = match self.obj().current_running_time() {
- Some(now) => now,
- None => {
- return true;
- }
- };
+ let Some(now) = self.obj().current_running_time() else { return true };
let latency = self.settings.lock().unwrap().latency;
@@ -249,9 +179,7 @@ impl Transcriber {
gst::trace!(
CAT,
imp: self,
- "Checking now {} if item is ready for dequeuing, PTS {}, threshold {} vs {}",
- now,
- pts,
+ "Checking now {now} if item is ready for dequeuing, PTS {pts}, threshold {} vs {}",
pts + latency.saturating_sub(3 * GRANULARITY),
now - start_time
);
@@ -295,7 +223,7 @@ impl Transcriber {
.duration(pts - last_position)
.seqnum(seqnum)
.build();
- gst::log!(CAT, "Pushing gap: {} -> {}", last_position, pts);
+ gst::log!(CAT, "Pushing gap: {last_position} -> {pts}");
if !self.srcpad.push_event(gap_event) {
return false;
}
@@ -306,9 +234,7 @@ impl Transcriber {
gst::warning!(
CAT,
imp: self,
- "Updating item PTS ({} < {}), consider increasing latency",
- pts,
- last_position
+ "Updating item PTS ({pts} < {last_position}), consider increasing latency",
);
pts = last_position;
@@ -326,7 +252,7 @@ impl Transcriber {
last_position = pts + duration;
- gst::debug!(CAT, "Pushing buffer: {} -> {}", pts, pts + duration);
+ gst::debug!(CAT, "Pushing buffer: {pts} -> {}", pts + duration);
if self.srcpad.push(buf).is_err() {
return false;
@@ -337,9 +263,7 @@ impl Transcriber {
gst::trace!(
CAT,
imp: self,
- "Checking now: {} if we need to push a gap, last_position: {}, threshold: {}",
- now,
- last_position,
+ "Checking now: {now} if we need to push a gap, last_position: {last_position}, threshold: {}",
last_position + latency.saturating_sub(GRANULARITY)
);
@@ -353,8 +277,7 @@ impl Transcriber {
gst::log!(
CAT,
- "Pushing gap: {} -> {}",
- last_position,
+ "Pushing gap: {last_position} -> {}",
last_position + duration
);
@@ -374,15 +297,15 @@ impl Transcriber {
true
}
- fn enqueue(&self, state: &mut State, alternative: &TranscriptAlternative, partial: bool) {
+ fn enqueue(&self, state: &mut State, items: &[model::Item], partial: bool) {
let lateness = self.settings.lock().unwrap().lateness;
- if alternative.items.len() <= state.partial_index {
+ if items.len() <= state.partial_index {
gst::error!(
CAT,
imp: self,
"sanity check failed, alternative length {} < partial_index {}",
- alternative.items.len(),
+ items.len(),
state.partial_index
);
@@ -393,40 +316,42 @@ impl Transcriber {
return;
}
- for item in &alternative.items[state.partial_index..] {
- let start_time =
- ((item.start_time as f64 * 1_000_000_000.0) as u64).nseconds() + lateness;
- let end_time = ((item.end_time as f64 * 1_000_000_000.0) as u64).nseconds() + lateness;
+ for item in &items[state.partial_index..] {
+ let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness;
+ let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness;
- if !item.stable {
+ if !item.stable().unwrap_or(false) {
break;
}
- /* Should be sent now */
- gst::debug!(
- CAT,
- imp: self,
- "Item is ready for queuing: {}, PTS {}",
- item.content,
- start_time
- );
- let mut buf = gst::Buffer::from_mut_slice(item.content.clone().into_bytes());
+ // FIXME could probably just unwrap
+ if let Some(content) = item.content() {
+ /* Should be sent now */
+ gst::debug!(
+ CAT,
+ imp: self,
+ "Item is ready for queuing: {content}, PTS {start_time}",
+ );
- {
- let buf = buf.get_mut().unwrap();
+ let mut buf = gst::Buffer::from_mut_slice(content.to_string().into_bytes());
+ {
+ let buf = buf.get_mut().unwrap();
- if state.discont {
- buf.set_flags(gst::BufferFlags::DISCONT);
- state.discont = false;
- }
+ if state.discont {
+ buf.set_flags(gst::BufferFlags::DISCONT);
+ state.discont = false;
+ }
- buf.set_pts(start_time);
- buf.set_duration(end_time - start_time);
- }
+ buf.set_pts(start_time);
+ buf.set_duration(end_time - start_time);
+ }
- state.partial_index += 1;
+ state.partial_index += 1;
- state.buffers.push_back(buf);
+ state.buffers.push_back(buf);
+ } else {
+ gst::debug!(CAT, imp: self, "None transcript item content");
+ }
}
if !partial {
@@ -434,12 +359,11 @@ impl Transcriber {
}
}
- fn loop_fn(&self, receiver: &mut mpsc::Receiver<Message>) -> Result<(), gst::ErrorMessage> {
+ fn pad_loop_fn(&self, receiver: &mut mpsc::Receiver<model::TranscriptEvent>) -> Result<(), ()> {
let mut events = {
let mut events = vec![];
- let mut state = self.state.lock().unwrap();
-
+ let state = self.state.lock().unwrap();
if state.send_events {
events.push(
gst::event::StreamStart::builder("transcription")
@@ -461,112 +385,71 @@ impl Transcriber {
.seqnum(state.seqnum)
.build(),
);
-
- state.send_events = false;
}
events
};
- for event in events.drain(..) {
- gst::info!(CAT, imp: self, "Sending {:?}", event);
- self.srcpad.push_event(event);
+ if !events.is_empty() {
+ for event in events.drain(..) {
+ gst::info!(CAT, imp: self, "Sending {event:?}");
+ self.srcpad.push_event(event);
+ }
+
+ self.state.lock().unwrap().send_events = false;
}
let future = async move {
- let msg = match receiver.next().await {
- Some(msg) => msg,
- /* Sender was closed */
- None => {
- let _ = self.srcpad.pause_task();
- return Ok(());
- }
- };
-
- match msg {
- Message::Binary(buf) => {
- let (_, pkt) = parse_packet(&buf).map_err(|err| {
- gst::error!(CAT, imp: self, "Failed to parse packet: {}", err);
- error_msg!(
- gst::StreamError::Failed,
- ["Failed to parse packet: {}", err]
- )
- })?;
-
- let payload = std::str::from_utf8(pkt.payload).unwrap();
-
- if packet_is_exception(&pkt) {
- let message: ExceptionMessage =
- serde_json::from_str(payload).map_err(|err| {
- gst::error!(
- CAT,
- imp: self,
- "Unexpected exception message: {} ({})",
- payload,
- err
- );
- error_msg!(
- gst::StreamError::Failed,
- ["Unexpected exception message: {} ({})", payload, err]
- )
- })?;
- gst::error!(CAT, imp: self, "AWS raised an error: {}", message.message);
-
- return Err(error_msg!(
- gst::StreamError::Failed,
- ["AWS raised an error: {}", message.message]
- ));
- }
-
- let transcript: Transcript = serde_json::from_str(payload).map_err(|err| {
- error_msg!(
- gst::StreamError::Failed,
- ["Unexpected binary message: {} ({})", payload, err]
- )
- })?;
+ enum Winner {
+ TranscriptEvent(Option<model::TranscriptEvent>),
+ Timeout,
+ }
- if let Some(result) = transcript.transcript.results.get(0) {
- gst::trace!(
- CAT,
- imp: self,
- "result: {}",
- serde_json::to_string_pretty(&result).unwrap(),
- );
+ let timer = tokio::time::sleep(GRANULARITY.into()).fuse();
+ futures::pin_mut!(timer);
- if let Some(alternative) = result.alternatives.get(0) {
- let mut state = self.state.lock().unwrap();
+ let race_res = futures::select_biased! {
+ transcript_evt = receiver.next() => Winner::TranscriptEvent(transcript_evt),
+ _ = timer => Winner::Timeout,
+ };
- self.enqueue(&mut state, alternative, result.is_partial)
+ use Winner::*;
+ match race_res {
+ TranscriptEvent(Some(transcript_evt)) => {
+ if let Some(result) = transcript_evt
+ .transcript
+ .as_ref()
+ .and_then(|transcript| transcript.results())
+ .and_then(|results| results.get(0))
+ {
+ gst::trace!(CAT, imp: self, "Received: {result:?}");
+
+ if let Some(alternative) = result
+ .alternatives
+ .as_ref()
+ .and_then(|alternatives| alternatives.get(0))
+ {
+ if let Some(items) = alternative.items() {
+ let mut state = self.state.lock().unwrap();
+ self.enqueue(&mut state, items, result.is_partial)
+ }
}
}
-
- Ok(())
}
-
- _ => Ok(()),
- }
- };
-
- /* Wrap in a timeout so we can push gaps regularly */
- let future = async move {
- match tokio::time::timeout(GRANULARITY.into(), future).await {
- Err(_) => {
- if !self.dequeue() {
- gst::info!(CAT, imp: self, "Failed to push gap event, pausing");
-
- let _ = self.srcpad.pause_task();
- }
- Ok(())
+ TranscriptEvent(None) => {
+ gst::info!(CAT, imp: self, "Transcript evt channel disconnected");
+ // Something bad happened elsewhere, let the other side report.
+ return Err(());
}
- Ok(res) => {
- if !self.dequeue() {
- gst::info!(CAT, imp: self, "Failed to push gap event, pausing");
+ Timeout => (),
+ }
- let _ = self.srcpad.pause_task();
- }
- res
- }
+ if !self.dequeue() {
+ gst::info!(CAT, imp: self, "Failed to dequeue buffer, pausing");
+ let _ = self.srcpad.pause_task();
}
+
+ Ok(())
};
let _enter = RUNTIME.enter();
@@ -574,26 +457,53 @@ impl Transcriber {
}
fn start_task(&self) -> Result<(), gst::LoggableError> {
- let (sender, mut receiver) = mpsc::channel(1);
+ let mut state = self.state.lock().unwrap();
- {
- let mut state = self.state.lock().unwrap();
- state.sender = Some(sender);
- }
+ let (transcript_tx, mut transcript_rx) = mpsc::channel(1);
let imp = self.ref_counted();
let res = self.srcpad.start_task(move || {
- if let Err(err) = imp.loop_fn(&mut receiver) {
- element_imp_error!(imp, gst::StreamError::Failed, ["Streaming failed: {}", err]);
+ if imp.pad_loop_fn(&mut transcript_rx).is_err() {
+ // Pad loop fn reported an unrecoverable error.
+ // FIXME we should probably stop the task as
+ // there's nothing we can do about it except restarting.
let _ = imp.srcpad.pause_task();
}
});
+
if res.is_err() {
- return Err(loggable_error!(CAT, "Failed to start pad task"));
+ state.transcript_tx = None;
+ return Err(gst::loggable_error!(CAT, "Failed to start pad task"));
}
+
+ state.transcript_tx = Some(transcript_tx);
+
Ok(())
}
+ fn stop_task(&self) {
+ let mut state = self.state.lock().unwrap();
+
+ let _ = self.srcpad.stop_task();
+
+ if let Some(ws_loop_handle) = state.ws_loop_handle.take() {
+ ws_loop_handle.abort();
+ }
+
+ state.transcript_tx = None;
+ state.buffer_tx = None;
+ }
+
+ fn stop_ws_loop(&self) {
+ let mut state = self.state.lock().unwrap();
+
+ if let Some(ws_loop_handle) = state.ws_loop_handle.take() {
+ ws_loop_handle.abort();
+ }
+
+ state.buffer_tx = None;
+ }
+
fn src_activatemode(
&self,
_pad: &gst::Pad,
@@ -603,24 +513,18 @@ impl Transcriber {
if active {
self.start_task()?;
} else {
- {
- let mut state = self.state.lock().unwrap();
- state.sender = None;
- }
-
- let _ = self.srcpad.stop_task();
+ self.stop_task();
}
Ok(())
}
fn src_query(&self, pad: &gst::Pad, query: &mut gst::QueryRef) -> bool {
- use gst::QueryViewMut;
-
- gst::log!(CAT, obj: pad, "Handling query {:?}", query);
+ gst::log!(CAT, obj: pad, "Handling query {query:?}");
+ use gst::QueryViewMut::*;
match query.view_mut() {
- QueryViewMut::Latency(q) => {
+ Latency(q) => {
let mut peer_query = gst::query::Latency::new();
let ret = self.sinkpad.peer_query(&mut peer_query);
@@ -632,7 +536,7 @@ impl Transcriber {
}
ret
}
- QueryViewMut::Position(q) => {
+ Position(q) => {
if q.format() == gst::Format::Time {
let state = self.state.lock().unwrap();
q.set(
@@ -650,44 +554,29 @@ impl Transcriber {
}
fn sink_event(&self, pad: &gst::Pad, event: gst::Event) -> bool {
- use gst::EventView;
-
- gst::log!(CAT, obj: pad, "Handling event {:?}", event);
+ gst::log!(CAT, obj: pad, "Handling event {event:?}");
+ use gst::EventView::*;
match event.view() {
- EventView::Eos(_) => match self.handle_buffer(pad, None) {
- Err(err) => {
- gst::error!(CAT, "Failed to send EOS to AWS: {}", err);
- false
- }
- Ok(_) => true,
- },
- EventView::FlushStart(_) => {
- gst::info!(CAT, imp: self, "Received flush start, disconnecting");
- let mut ret = gst::Pad::event_default(pad, Some(&*self.obj()), event);
-
- match self.srcpad.stop_task() {
- Err(err) => {
- gst::error!(CAT, imp: self, "Failed to stop srcpad task: {}", err);
+ Eos(_) => {
+ self.stop_ws_loop();
- self.disconnect();
-
- ret = false;
- }
- Ok(_) => {
- self.disconnect();
- }
- };
+ true
+ }
+ FlushStart(_) => {
+ gst::info!(CAT, imp: self, "Received flush start, disconnecting");
+ let ret = gst::Pad::event_default(pad, Some(&*self.obj()), event);
+ self.stop_task();
ret
}
- EventView::FlushStop(_) => {
+ FlushStop(_) => {
gst::info!(CAT, imp: self, "Received flush stop, restarting task");
if gst::Pad::event_default(pad, Some(&*self.obj()), event) {
match self.start_task() {
Err(err) => {
- gst::error!(CAT, imp: self, "Failed to start srcpad task: {}", err);
+ gst::error!(CAT, imp: self, "Failed to start srcpad task: {err}");
false
}
Ok(_) => true,
@@ -696,13 +585,13 @@ impl Transcriber {
false
}
}
- EventView::Segment(e) => {
+ Segment(e) => {
let segment = match e.segment().clone().downcast::<gst::ClockTime>() {
Err(segment) => {
- element_imp_error!(
+ gst::element_imp_error!(
self,
gst::StreamError::Format,
- ["Only Time segments supported, got {:?}", segment.format(),]
+ ["Only Time segments supported, got {:?}", segment.format()]
);
return false;
}
@@ -716,355 +605,285 @@ impl Transcriber {
true
}
- EventView::Tag(_) => true,
- EventView::Caps(e) => {
- gst::info!(CAT, "Received caps {:?}", e);
+ Tag(_) => true,
+ Caps(c) => {
+ gst::info!(CAT, "Received caps {c:?}");
true
}
- EventView::StreamStart(_) => true,
+ StreamStart(_) => true,
_ => gst::Pad::event_default(pad, Some(&*self.obj()), event),
}
}
- async fn sync_and_send(
- &self,
- buffer: Option<gst::Buffer>,
- ) -> Result<gst::FlowSuccess, gst::FlowError> {
- let mut delay = None;
-
- {
- let state = self.state.lock().unwrap();
-
- if let Some(buffer) = &buffer {
- let running_time = state.in_segment.to_running_time(buffer.pts());
- let now = self.obj().current_running_time();
-
- delay = running_time.opt_checked_sub(now).ok().flatten();
- }
- }
-
- if let Some(delay) = delay {
- tokio::time::sleep(delay.into()).await;
- }
-
- if let Some(ws_sink) = self.ws_sink.borrow_mut().as_mut() {
- if let Some(buffer) = buffer {
- let data = buffer.map_readable().unwrap();
- for chunk in data.chunks(8192) {
- let packet = build_packet(chunk);
- ws_sink.send(Message::Binary(packet)).await.map_err(|err| {
- gst::error!(CAT, imp: self, "Failed sending packet: {}", err);
- gst::FlowError::Error
- })?;
- }
- } else {
- // EOS
- let packet = build_packet(&[]);
- ws_sink.send(Message::Binary(packet)).await.map_err(|err| {
- gst::error!(CAT, imp: self, "Failed sending packet: {}", err);
- gst::FlowError::Error
- })?;
- }
- }
-
- Ok(gst::FlowSuccess::Ok)
- }
-
- fn handle_buffer(
+ fn sink_chain(
&self,
- _pad: &gst::Pad,
- buffer: Option<gst::Buffer>,
+ pad: &gst::Pad,
+ buffer: gst::Buffer,
) -> Result<gst::FlowSuccess, gst::FlowError> {
- gst::log!(CAT, imp: self, "Handling {:?}", buffer);
+ gst::log!(CAT, obj: pad, "Handling {buffer:?}");
self.ensure_connection().map_err(|err| {
- element_imp_error!(
- self,
- gst::StreamError::Failed,
- ["Streaming failed: {}", err]
- );
+ gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]);
gst::FlowError::Error
})?;
- let (future, abort_handle) = abortable(self.sync_and_send(buffer));
-
- self.state.lock().unwrap().send_abort_handle = Some(abort_handle);
-
- let res = {
- let _enter = RUNTIME.enter();
- futures::executor::block_on(future)
+ let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else {
+ gst::log!(CAT, obj: pad, "Flushing");
+ return Err(gst::FlowError::Flushing);
};
- match res {
- Err(_) => Err(gst::FlowError::Flushing),
- Ok(res) => res,
- }
- }
+ futures::executor::block_on(buffer_tx.send(buffer)).map_err(|err| {
+ gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]);
+ gst::FlowError::Error
+ })?;
- fn sink_chain(
- &self,
- pad: &gst::Pad,
- buffer: gst::Buffer,
- ) -> Result<gst::FlowSuccess, gst::FlowError> {
- self.handle_buffer(pad, Some(buffer))
+ self.state.lock().unwrap().buffer_tx = Some(buffer_tx);
+
+ Ok(gst::FlowSuccess::Ok)
}
fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> {
- let state = self.state.lock().unwrap();
-
- if state.connected {
- return Ok(());
+ enum ClientStage {
+ Ready(aws_transcribe::Client),
+ NotReady {
+ access_key: Option<String>,
+ secret_access_key: Option<String>,
+ session_token: Option<String>,
+ },
}
- let in_caps = self.sinkpad.current_caps().unwrap();
- let s = in_caps.structure(0).unwrap();
- let sample_rate = s.get::<i32>("rate").unwrap();
+ let (client_stage, transcription_settings, transcript_tx) = {
+ let mut state = self.state.lock().unwrap();
- let settings = self.settings.lock().unwrap();
+ if let Some(ref ws_loop_handle) = state.ws_loop_handle {
+ if ws_loop_handle.is_finished() {
+ state.ws_loop_handle = None;
- if settings.latency + settings.lateness <= 2 * GRANULARITY {
- gst::error!(
- CAT,
- imp: self,
- "latency + lateness must be greater than 200 milliseconds"
- );
- return Err(error_msg!(
- gst::LibraryError::Settings,
- ["latency + lateness must be greater than 200 milliseconds"]
- ));
- }
+ const ERR: &str = "ws loop terminated unexpectedly";
+ gst::error!(CAT, imp: self, "{ERR}");
+ return Err(gst::error_msg!(gst::LibraryError::Failed, ["{ERR}"]));
+ }
- gst::info!(CAT, imp: self, "Connecting ..");
+ return Ok(());
+ }
- let region = Region::new(DEFAULT_TRANSCRIBER_REGION);
- let access_key = settings.access_key.as_ref();
- let secret_access_key = settings.secret_access_key.as_ref();
- let session_token = settings.session_token.clone();
+ let transcript_tx = state
+ .transcript_tx
+ .take()
+ .expect("attempting to spawn the ws loop, but the srcpad task hasn't been started");
- let credentials = match (access_key, secret_access_key) {
- (Some(key), Some(secret_key)) => {
- gst::debug!(
- CAT,
- imp: self,
- "Using provided access and secret access key"
- );
- Ok(Credentials::new(
- key.clone(),
- secret_key.clone(),
- session_token,
- None,
- "transcribe",
- ))
- }
- _ => {
- gst::debug!(CAT, imp: self, "Using default AWS credentials");
- let cred_future = async {
- let cred = DefaultCredentialsChain::builder()
- .region(region.clone())
- .build()
- .await;
- cred.provide_credentials().await
- };
+ let settings = self.settings.lock().unwrap();
- RUNTIME.block_on(cred_future)
+ if settings.latency + settings.lateness <= 2 * GRANULARITY {
+ const ERR: &str = "latency + lateness must be greater than 200 milliseconds";
+ gst::error!(CAT, imp: self, "{ERR}");
+ return Err(gst::error_msg!(gst::LibraryError::Settings, ["{ERR}"]));
}
- };
- if let Err(e) = credentials {
- return Err(error_msg!(
- gst::LibraryError::Settings,
- ["Failed to retrieve credentials with error {}", e]
- ));
- }
+ let in_caps = self.sinkpad.current_caps().unwrap();
+ let s = in_caps.structure(0).unwrap();
+ let sample_rate = s.get::<i32>("rate").unwrap();
- let current_time = Utc::now();
+ let transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
- let mut query_params = String::from("/stream-transcription-websocket?");
+ let client_stage = if let Some(client) = state.client.take() {
+ ClientStage::Ready(client)
+ } else {
+ ClientStage::NotReady {
+ access_key: settings.access_key.to_owned(),
+ secret_access_key: settings.secret_access_key.to_owned(),
+ session_token: settings.session_token.to_owned(),
+ }
+ };
- let language_code = settings
- .language_code
- .as_ref()
- .expect("Language code is required");
+ (client_stage, transcription_settings, transcript_tx)
+ };
- query_params.push_str(
- format!(
- "language-code={}&media-encoding=pcm&sample-rate={}",
- language_code,
- &sample_rate.to_string(),
- )
- .as_str(),
- );
+ let client = match client_stage {
+ ClientStage::Ready(client) => client,
+ ClientStage::NotReady {
+ access_key,
+ secret_access_key,
+ session_token,
+ } => {
+ gst::info!(CAT, imp: self, "Connecting...");
+ let _enter_guard = RUNTIME.enter();
+
+ let config_loader = match (access_key, secret_access_key) {
+ (Some(key), Some(secret_key)) => {
+ gst::debug!(CAT, imp: self, "Using settings credentials");
+ aws_config::ConfigLoader::default().credentials_provider(
+ aws_transcribe::Credentials::new(
+ key,
+ secret_key,
+ session_token,
+ None,
+ "translate",
+ ),
+ )
+ }
+ _ => {
+ gst::debug!(CAT, imp: self, "Attempting to get credentials from env...");
+ aws_config::from_env()
+ }
+ };
- if let Some(ref vocabulary) = settings.vocabulary {
- query_params.push_str(format!("&vocabulary-name={vocabulary}").as_str());
- }
+ let config_loader = config_loader.region(
+ aws_config::meta::region::RegionProviderChain::default_provider()
+ .or_else(DEFAULT_TRANSCRIBER_REGION),
+ );
+ let config = futures::executor::block_on(config_loader.load());
+ gst::debug!(CAT, imp: self, "Using region {}", config.region().unwrap());
- if let Some(ref vocabulary_filter) = settings.vocabulary_filter {
- query_params.push_str(format!("&vocabulary-filter-name={vocabulary_filter}").as_str());
+ aws_transcribe::Client::new(&config)
+ }
+ };
- query_params.push_str(
- format!(
- "&vocabulary-filter-method={}",
- match settings.vocabulary_filter_method {
- AwsTranscriberVocabularyFilterMethod::Mask => "mask",
- AwsTranscriberVocabularyFilterMethod::Remove => "remove",
- AwsTranscriberVocabularyFilterMethod::Tag => "tag",
- }
- )
- .as_str(),
- );
- }
+ let mut state = self.state.lock().unwrap();
- if let Some(ref session_id) = settings.session_id {
- gst::debug!(CAT, imp: self, "Using session ID: {}", session_id);
- query_params.push_str(format!("&session-id={session_id}").as_str());
- }
+ let (buffer_tx, buffer_rx) = mpsc::channel(1);
+ let ws_loop_handle = RUNTIME.spawn(self.build_ws_loop_fut(
+ client,
+ transcription_settings,
+ buffer_rx,
+ transcript_tx,
+ ));
- query_params.push_str("&enable-partial-results-stabilization=true");
+ state.ws_loop_handle = Some(ws_loop_handle);
+ state.buffer_tx = Some(buffer_tx);
- query_params.push_str(
- format!(
- "&partial-results-stability={}",
- match settings.results_stability {
- AwsTranscriberResultStability::High => "high",
- AwsTranscriberResultStability::Medium => "medium",
- AwsTranscriberResultStability::Low => "low",
- }
- )
- .as_str(),
- );
+ Ok(())
+ }
- drop(settings);
- drop(state);
+ fn build_ws_loop_fut(
+ &self,
+ client: aws_transcribe::Client,
+ settings: TranscriptionSettings,
+ buffer_rx: mpsc::Receiver<gst::Buffer>,
+ transcript_tx: mpsc::Sender<model::TranscriptEvent>,
+ ) -> impl Future<Output = Result<(), gst::ErrorMessage>> {
+ let imp_weak = self.downgrade();
+ async move {
+ use gst::glib::subclass::ObjectImplWeakRef;
+
+ // Guard that restores client & transcript_tx when the ws loop is done
+ struct Guard {
+ imp_weak: ObjectImplWeakRef<Transcriber>,
+ client: Option<aws_transcribe::Client>,
+ transcript_tx: Option<mpsc::Sender<model::TranscriptEvent>>,
+ }
- let signer = signer::SigV4Signer::new();
- let mut operation_config = OperationSigningConfig::default_config();
- operation_config.signature_type = HttpSignatureType::HttpRequestQueryParams;
- operation_config.expires_in = Some(Duration::from_secs(5 * 60)); // See commit a3db85d.
+ impl Guard {
+ fn client(&self) -> &aws_transcribe::Client {
+ self.client.as_ref().unwrap()
+ }
- let request_config = RequestConfig {
- request_ts: SystemTime::from(current_time),
- region: &SigningRegion::from(region.clone()),
- service: &SigningService::from_static("transcribe"),
- payload_override: None,
- };
- let transcribe_uri = Uri::builder()
- .scheme("https")
- .authority(format!("transcribestreaming.{region}.amazonaws.com:8443").as_str())
- .path_and_query(query_params.clone())
- .build()
- .map_err(|err| {
- gst::error!(CAT, imp: self, "Failed to build HTTP request URI: {}", err);
- error_msg!(
- gst::CoreError::Failed,
- ["Failed to build HTTP request URI: {}", err]
- )
- })?;
- let mut request = http::Request::builder()
- .uri(transcribe_uri)
- .body(SdkBody::empty())
- .expect("Failed to build valid request");
- let _signature = signer
- .sign(
- &operation_config,
- &request_config,
- &credentials.unwrap(),
- &mut request,
- )
- .map_err(|err| {
- gst::error!(CAT, imp: self, "Failed to sign HTTP request: {}", err);
- error_msg!(
- gst::CoreError::Failed,
- ["Failed to sign HTTP request: {}", err]
- )
- })?;
- let url = request.uri().to_string();
-
- let (ws, _) = {
- let _enter = RUNTIME.enter();
- futures::executor::block_on(connect_async(format!("wss{}", &url[5..]))).map_err(
- |err| {
- gst::error!(CAT, imp: self, "Failed to connect: {}", err);
- error_msg!(gst::CoreError::Failed, ["Failed to connect: {}", err])
- },
- )?
- };
+ fn transcript_tx(&mut self) -> &mut mpsc::Sender<model::TranscriptEvent> {
+ self.transcript_tx.as_mut().unwrap()
+ }
+ }
- let (ws_sink, mut ws_stream) = ws.split();
+ impl Drop for Guard {
+ fn drop(&mut self) {
+ if let Some(imp) = self.imp_weak.upgrade() {
+ let mut state = imp.state.lock().unwrap();
+ state.client = self.client.take();
+ state.transcript_tx = self.transcript_tx.take();
+ }
+ }
+ }
- *self.ws_sink.borrow_mut() = Some(Box::pin(ws_sink));
+ let mut guard = Guard {
+ imp_weak: imp_weak.clone(),
+ client: Some(client),
+ transcript_tx: Some(transcript_tx),
+ };
- let imp_weak = self.downgrade();
- let future = async move {
- while let Some(transcribe) = imp_weak.upgrade() {
- let msg = match ws_stream.next().await {
- Some(msg) => msg,
- None => {
- let mut state = transcribe.state.lock().unwrap();
- state.send_eos = true;
- break;
+ // Stream the incoming buffers chunked
+ let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| {
+ async_stream::stream! {
+ let data = buffer.map_readable().unwrap();
+ use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
+ for chunk in data.chunks(8192) {
+ yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build()));
}
- };
+ }
+ });
+
+ let mut transcribe_builder = guard
+ .client()
+ .start_stream_transcription()
+ .language_code(settings.lang_code)
+ .media_sample_rate_hertz(settings.sample_rate)
+ .media_encoding(model::MediaEncoding::Pcm)
+ .enable_partial_results_stabilization(true)
+ .partial_results_stability(settings.results_stability)
+ .set_vocabulary_name(settings.vocabulary)
+ .set_session_id(settings.session_id);
+
+ if let Some(vocabulary_filter) = settings.vocabulary_filter {
+ transcribe_builder = transcribe_builder
+ .vocabulary_filter_name(vocabulary_filter)
+ .vocabulary_filter_method(settings.vocabulary_filter_method);
+ }
- let msg = match msg {
- Ok(msg) => msg,
- Err(err) => {
- gst::error!(CAT, imp: transcribe, "Failed to receive data: {}", err);
- element_imp_error!(
- transcribe,
- gst::StreamError::Failed,
- ["Streaming failed: {}", err]
- );
- break;
+ let mut output = transcribe_builder
+ .audio_stream(chunk_stream.into())
+ .send()
+ .await
+ .map_err(|err| {
+ let err = format!("Transcribe ws init error: {err}");
+ if let Some(imp) = imp_weak.upgrade() {
+ gst::error!(CAT, imp: imp, "{err}");
}
- };
-
- let mut sender = transcribe.state.lock().unwrap().sender.clone();
+ gst::error_msg!(gst::LibraryError::Init, ["{err}"])
+ })?;
- if let Some(sender) = sender.as_mut() {
- if sender.send(msg).await.is_err() {
+ while let Some(event) = output
+ .transcript_result_stream
+ .recv()
+ .await
+ .map_err(|err| {
+ let err = format!("Transcribe ws stream error: {err}");
+ if let Some(imp) = imp_weak.upgrade() {
+ gst::error!(CAT, imp: imp, "{err}");
+ }
+ gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
+ })?
+ {
+ if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
+ if guard.transcript_tx().send(transcript_evt).await.is_err() {
+ if let Some(imp) = imp_weak.upgrade() {
+ gst::debug!(CAT, imp: imp, "Terminated transcript_evt channel");
+ }
break;
}
+ } else if let Some(imp) = imp_weak.upgrade() {
+ gst::warning!(
+ CAT,
+ imp: imp,
+ "Transcribe ws returned unknown event: consider upgrading the SDK"
+ )
+ } else {
+ // imp has left the building
+ break;
}
}
- };
-
- let mut state = self.state.lock().unwrap();
-
- let (future, abort_handle) = abortable(future);
-
- state.recv_abort_handle = Some(abort_handle);
-
- RUNTIME.spawn(future);
- state.connected = true;
-
- gst::info!(CAT, imp: self, "Connected");
+ if let Some(imp) = imp_weak.upgrade() {
+ gst::debug!(CAT, imp: imp, "Exiting ws loop");
+ }
- Ok(())
+ Ok(())
+ }
}
fn disconnect(&self) {
let mut state = self.state.lock().unwrap();
-
gst::info!(CAT, imp: self, "Unpreparing");
-
- if let Some(abort_handle) = state.recv_abort_handle.take() {
- abort_handle.abort();
- }
-
- if let Some(abort_handle) = state.send_abort_handle.take() {
- abort_handle.abort();
- }
-
+ self.stop_task();
*state = State::default();
-
- gst::info!(
- CAT,
- imp: self,
- "Unprepared, connected: {}!",
- state.connected
- );
+ gst::info!(CAT, imp: self, "Unprepared");
}
}
@@ -1098,7 +917,12 @@ impl ObjectSubclass for Transcriber {
.activatemode_function(|pad, parent, mode, active| {
Transcriber::catch_panic_pad_function(
parent,
- || Err(loggable_error!(CAT, "Panic activating src pad with mode")),
+ || {
+ Err(gst::loggable_error!(
+ CAT,
+ "Panic activating src pad with mode"
+ ))
+ },
|transcriber| transcriber.src_activatemode(pad, mode, active),
)
})
@@ -1119,7 +943,6 @@ impl ObjectSubclass for Transcriber {
sinkpad,
settings,
state: Mutex::new(State::default()),
- ws_sink: AtomicRefCell::new(None),
}
}
}
@@ -1133,7 +956,7 @@ impl ObjectImpl for Transcriber {
.blurb("The Language of the Stream, see \
<https://docs.aws.amazon.com/transcribe/latest/dg/how-streaming-transcription.html> \
for an up to date list of allowed languages")
- .default_value(Some("en-US"))
+ .default_value(Some(DEFAULT_LANGUAGE_CODE))
.mutable_ready()
.build(),
glib::ParamSpecUInt::builder("latency")
@@ -1325,7 +1148,7 @@ impl ElementImpl for Transcriber {
"Transcriber",
"Audio/Text/Filter",
"Speech to Text filter, using AWS transcribe",
- "Jordan Petridis <jordan@centricular.com>, Mathieu Duponchelle <mathieu@centricular.com>",
+ "Jordan Petridis <jordan@centricular.com>, Mathieu Duponchelle <mathieu@centricular.com>, François Laignel <francois@centricular.com>",
)
});
@@ -1368,7 +1191,7 @@ impl ElementImpl for Transcriber {
&self,
transition: gst::StateChange,
) -> Result<gst::StateChangeSuccess, gst::StateChangeError> {
- gst::info!(CAT, imp: self, "Changing state {:?}", transition);
+ gst::info!(CAT, imp: self, "Changing state {transition:?}");
let mut success = self.parent_change_state(transition)?;
diff --git a/net/aws/src/transcriber/mod.rs b/net/aws/src/transcriber/mod.rs
index db04c9ac..69ac6059 100644
--- a/net/aws/src/transcriber/mod.rs
+++ b/net/aws/src/transcriber/mod.rs
@@ -10,7 +10,8 @@ use gst::glib;
use gst::prelude::*;
mod imp;
-mod packet;
+
+use aws_sdk_transcribestreaming::model::{PartialResultsStability, VocabularyFilterMethod};
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::Enum)]
#[repr(u32)]
@@ -31,6 +32,17 @@ pub enum AwsTranscriberResultStability {
Low = 2,
}
+impl From<AwsTranscriberResultStability> for PartialResultsStability {
+ fn from(val: AwsTranscriberResultStability) -> Self {
+ use AwsTranscriberResultStability::*;
+ match val {
+ High => PartialResultsStability::High,
+ Medium => PartialResultsStability::Medium,
+ Low => PartialResultsStability::Low,
+ }
+ }
+}
+
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::Enum)]
#[repr(u32)]
#[enum_type(name = "GstAwsTranscriberVocabularyFilterMethod")]
@@ -44,6 +56,17 @@ pub enum AwsTranscriberVocabularyFilterMethod {
Tag = 2,
}
+impl From<AwsTranscriberVocabularyFilterMethod> for VocabularyFilterMethod {
+ fn from(val: AwsTranscriberVocabularyFilterMethod) -> Self {
+ use AwsTranscriberVocabularyFilterMethod::*;
+ match val {
+ Mask => VocabularyFilterMethod::Mask,
+ Remove => VocabularyFilterMethod::Remove,
+ Tag => VocabularyFilterMethod::Tag,
+ }
+ }
+}
+
glib::wrapper! {
pub struct Transcriber(ObjectSubclass<imp::Transcriber>) @extends gst::Element, gst::Object;
}
diff --git a/net/aws/src/transcriber/packet/mod.rs b/net/aws/src/transcriber/packet/mod.rs
deleted file mode 100644
index d11a054c..00000000
--- a/net/aws/src/transcriber/packet/mod.rs
+++ /dev/null
@@ -1,174 +0,0 @@
-// Copyright (C) 2020 Jordan Petridis <jordan@centricular.com>
-//
-// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0.
-// If a copy of the MPL was not distributed with this file, You can obtain one at
-// <https://mozilla.org/MPL/2.0/>.
-//
-// SPDX-License-Identifier: MPL-2.0
-
-use byteorder::{BigEndian, WriteBytesExt};
-use nom::{
- self, bytes::complete::take, combinator::map_res, multi::many0, number::complete::be_u16,
- number::complete::be_u32, number::complete::be_u8, sequence::tuple, IResult,
-};
-use std::borrow::Cow;
-use std::io::{self, Write};
-
-const CRC: crc::Crc<u32> = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC);
-
-#[derive(Debug)]
-struct Prelude {
- total_bytes: u32,
- header_bytes: u32,
- #[allow(dead_code)]
- prelude_crc: u32,
-}
-
-#[derive(Debug)]
-pub struct Header {
- pub name: Cow<'static, str>,
- pub value_type: u8,
- pub value: Cow<'static, str>,
-}
-
-#[derive(Debug)]
-pub struct Packet<'a> {
- #[allow(dead_code)]
- prelude: Prelude,
- headers: Vec<Header>,
- pub payload: &'a [u8],
- #[allow(dead_code)]
- msg_crc: u32,
-}
-
-fn write_header<W: Write>(w: &mut W, header: &Header) -> Result<(), io::Error> {
- w.write_u8(header.name.len() as u8)?;
- w.write_all(header.name.as_bytes())?;
- w.write_u8(header.value_type)?;
- w.write_u16::<BigEndian>(header.value.len() as u16)?;
- w.write_all(header.value.as_bytes())?;
- Ok(())
-}
-
-fn write_headers<W: Write>(w: &mut W, headers: &[Header]) -> Result<(), io::Error> {
- for header in headers {
- write_header(w, header)?;
- }
- Ok(())
-}
-
-pub fn encode_packet(payload: &[u8], headers: &[Header]) -> Result<Vec<u8>, io::Error> {
- let mut res = Vec::with_capacity(1024);
-
- // Total length
- res.write_u32::<BigEndian>(0)?;
- // Header length
- res.write_u32::<BigEndian>(0)?;
- // Prelude CRC32 placeholder
- res.write_u32::<BigEndian>(0)?;
-
- // Write all headers
- write_headers(&mut res, headers)?;
-
- // Rewrite header length
- let header_length = res.len() - 12;
- (&mut res[4..8]).write_u32::<BigEndian>(header_length as u32)?;
-
- // Write payload
- res.write_all(payload)?;
-
- // Rewrite total length
- let total_length = res.len() + 4;
- (&mut res[0..4]).write_u32::<BigEndian>(total_length as u32)?;
-
- // Rewrite the prelude crc since we replaced the lengths
- let prelude_crc = CRC.checksum(&res[0..8]);
- (&mut res[8..12]).write_u32::<BigEndian>(prelude_crc)?;
-
- // Message CRC
- let message_crc = CRC.checksum(&res);
- res.write_u32::<BigEndian>(message_crc)?;
-
- Ok(res)
-}
-
-fn parse_prelude(input: &[u8]) -> IResult<&[u8], Prelude> {
- map_res(
- tuple((be_u32, be_u32, be_u32)),
- |(total_bytes, header_bytes, prelude_crc)| {
- let sum = CRC.checksum(&input[0..8]);
- if prelude_crc != sum {
- return Err(nom::Err::Error((
- "Prelude CRC doesn't match",
- nom::error::ErrorKind::MapRes,
- )));
- }
-
- Ok(Prelude {
- total_bytes,
- header_bytes,
- prelude_crc,
- })
- },
- )(input)
-}
-
-fn parse_header(input: &[u8]) -> IResult<&[u8], Header> {
- let (input, header_length) = be_u8(input)?;
- let (input, name) = map_res(take(header_length), std::str::from_utf8)(input)?;
- let (input, value_type) = be_u8(input)?;
- let (input, value_length) = be_u16(input)?;
- let (input, value) = map_res(take(value_length), std::str::from_utf8)(input)?;
-
- let header = Header {
- name: name.to_string().into(),
- value_type,
- value: value.to_string().into(),
- };
-
- Ok((input, header))
-}
-
-pub fn packet_is_exception(packet: &Packet) -> bool {
- for header in &packet.headers {
- if header.name == ":message-type" && header.value == "exception" {
- return true;
- }
- }
-
- false
-}
-
-pub fn parse_packet(input: &[u8]) -> IResult<&[u8], Packet> {
- let (remainder, prelude) = parse_prelude(input)?;
-
- // Check the crc of the whole input
- let sum = CRC.checksum(&input[..input.len() - 4]);
- let (_, msg_crc) = be_u32(&input[input.len() - 4..])?;
-
- if msg_crc != sum {
- return Err(nom::Err::Error(nom::error::Error::new(
- b"Prelude CRC doesn't match",
- nom::error::ErrorKind::MapRes,
- )));
- }
-
- let (remainder, header_input) = take(prelude.header_bytes)(remainder)?;
- let (_, headers) = many0(parse_header)(header_input)?;
-
- let payload_length = prelude.total_bytes - prelude.header_bytes - 4 - 12;
- let (remainder, payload) = take(payload_length)(remainder)?;
-
- // only the message_crc we check before should be remaining now
- assert_eq!(remainder.len(), 4);
-
- Ok((
- input,
- Packet {
- prelude,
- headers,
- payload,
- msg_crc,
- },
- ))
-}