Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/sdroege/gst-plugin-rs.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/net/aws
diff options
context:
space:
mode:
authorFrançois Laignel <francois@centricular.com>2023-03-01 11:44:51 +0300
committerGStreamer Marge Bot <gitlab-merge-bot@gstreamer-foundation.org>2023-03-01 11:47:58 +0300
commit4a988aaeb8e650b3680514025fe7d52c6b10c659 (patch)
tree7cee38ee2d7d18bbf3f12eafb7437988d4cd3336 /net/aws
parentf1a080c94e5ec4c5b9f21b1606dd507ccd34cea5 (diff)
net/aws/transcriber: use a TranscriberLoop struct
This helps gather together the details related to the `TranscriberLoop`. One difference with previous implementation is that the ws `Client` is build each time the loop is started instead of being reused. With the new approach, we don't keep the connection open after EOS and we should be more resistant in case of a connection failure. Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1104>
Diffstat (limited to 'net/aws')
-rw-r--r--net/aws/src/transcriber/imp.rs297
1 files changed, 136 insertions, 161 deletions
diff --git a/net/aws/src/transcriber/imp.rs b/net/aws/src/transcriber/imp.rs
index 63f1c0e3..7af8f50f 100644
--- a/net/aws/src/transcriber/imp.rs
+++ b/net/aws/src/transcriber/imp.rs
@@ -109,8 +109,124 @@ impl TranscriptionSettings {
}
}
+struct TranscriberLoop {
+ imp: glib::subclass::ObjectImplRef<Transcriber>,
+ client: aws_transcribe::Client,
+ settings: TranscriptionSettings,
+ lateness: gst::ClockTime,
+ buffer_rx: mpsc::Receiver<gst::Buffer>,
+ transcript_notif_tx: mpsc::Sender<()>,
+}
+
+impl TranscriberLoop {
+ fn new(
+ imp: &Transcriber,
+ aws_config: &aws_config::SdkConfig,
+ settings: TranscriptionSettings,
+ lateness: gst::ClockTime,
+ buffer_rx: mpsc::Receiver<gst::Buffer>,
+ transcript_notif_tx: mpsc::Sender<()>,
+ ) -> Self {
+ TranscriberLoop {
+ imp: imp.ref_counted(),
+ client: aws_transcribe::Client::new(aws_config),
+ settings,
+ lateness,
+ buffer_rx,
+ transcript_notif_tx,
+ }
+ }
+
+ async fn run(mut self) -> Result<(), gst::ErrorMessage> {
+ // Stream the incoming buffers chunked
+ let chunk_stream = self.buffer_rx.flat_map(move |buffer: gst::Buffer| {
+ async_stream::stream! {
+ let data = buffer.map_readable().unwrap();
+ use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
+ for chunk in data.chunks(8192) {
+ yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build()));
+ }
+ }
+ });
+
+ let mut transcribe_builder = self
+ .client
+ .start_stream_transcription()
+ .language_code(self.settings.lang_code)
+ .media_sample_rate_hertz(self.settings.sample_rate)
+ .media_encoding(model::MediaEncoding::Pcm)
+ .enable_partial_results_stabilization(true)
+ .partial_results_stability(self.settings.results_stability)
+ .set_vocabulary_name(self.settings.vocabulary)
+ .set_session_id(self.settings.session_id);
+
+ if let Some(vocabulary_filter) = self.settings.vocabulary_filter {
+ transcribe_builder = transcribe_builder
+ .vocabulary_filter_name(vocabulary_filter)
+ .vocabulary_filter_method(self.settings.vocabulary_filter_method);
+ }
+
+ let mut output = transcribe_builder
+ .audio_stream(chunk_stream.into())
+ .send()
+ .await
+ .map_err(|err| {
+ let err = format!("Transcribe ws init error: {err}");
+ gst::error!(CAT, imp: self.imp, "{err}");
+ gst::error_msg!(gst::LibraryError::Init, ["{err}"])
+ })?;
+
+ while let Some(event) = output
+ .transcript_result_stream
+ .recv()
+ .await
+ .map_err(|err| {
+ let err = format!("Transcribe ws stream error: {err}");
+ gst::error!(CAT, imp: self.imp, "{err}");
+ gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
+ })?
+ {
+ if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
+ let mut enqueued = false;
+
+ if let Some(result) = transcript_evt
+ .transcript
+ .and_then(|transcript| transcript.results)
+ .and_then(|mut results| results.drain(..).next())
+ {
+ gst::trace!(CAT, imp: self.imp, "Received: {result:?}");
+
+ if let Some(alternative) = result
+ .alternatives
+ .and_then(|mut alternatives| alternatives.drain(..).next())
+ {
+ if let Some(items) = alternative.items {
+ enqueued = self.imp.enqueue(items, result.is_partial, self.lateness);
+ }
+ }
+ }
+
+ if enqueued && self.transcript_notif_tx.send(()).await.is_err() {
+ gst::debug!(CAT, imp: self.imp, "Terminated transcript_notif_tx channel");
+ break;
+ }
+ } else {
+ gst::warning!(
+ CAT,
+ imp: self.imp,
+ "Transcribe ws returned unknown event: consider upgrading the SDK"
+ )
+ }
+ }
+
+ gst::debug!(CAT, imp: self.imp, "Exiting ws loop");
+
+ Ok(())
+ }
+}
+
struct State {
- client: Option<aws_transcribe::Client>,
+ aws_config: Option<aws_config::SdkConfig>,
buffer_tx: Option<mpsc::Sender<gst::Buffer>>,
transcript_notif_tx: Option<mpsc::Sender<()>>,
ws_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
@@ -128,7 +244,7 @@ struct State {
impl Default for State {
fn default() -> Self {
Self {
- client: None,
+ aws_config: None,
buffer_tx: None,
transcript_notif_tx: None,
ws_loop_handle: None,
@@ -615,8 +731,8 @@ impl Transcriber {
}
fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> {
- enum ClientStage {
- Ready(aws_transcribe::Client),
+ enum ConfigStatus {
+ Ready(aws_config::SdkConfig),
NotReady {
access_key: Option<String>,
secret_access_key: Option<String>,
@@ -624,7 +740,7 @@ impl Transcriber {
},
}
- let (client_stage, transcription_settings, lateness, transcript_notif_tx);
+ let (config_status, transcription_settings, lateness, transcript_notif_tx);
{
let mut state = self.state.lock().unwrap();
@@ -660,10 +776,10 @@ impl Transcriber {
transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
- client_stage = if let Some(client) = state.client.take() {
- ClientStage::Ready(client)
+ config_status = if let Some(aws_config) = state.aws_config.take() {
+ ConfigStatus::Ready(aws_config)
} else {
- ClientStage::NotReady {
+ ConfigStatus::NotReady {
access_key: settings.access_key.to_owned(),
secret_access_key: settings.secret_access_key.to_owned(),
session_token: settings.session_token.to_owned(),
@@ -671,14 +787,14 @@ impl Transcriber {
};
};
- let client = match client_stage {
- ClientStage::Ready(client) => client,
- ClientStage::NotReady {
+ let aws_config = match config_status {
+ ConfigStatus::Ready(aws_config) => aws_config,
+ ConfigStatus::NotReady {
access_key,
secret_access_key,
session_token,
} => {
- gst::info!(CAT, imp: self, "Connecting...");
+ gst::info!(CAT, imp: self, "Loading aws config...");
let _enter_guard = RUNTIME.enter();
let config_loader = match (access_key, secret_access_key) {
@@ -707,172 +823,31 @@ impl Transcriber {
let config = futures::executor::block_on(config_loader.load());
gst::debug!(CAT, imp: self, "Using region {}", config.region().unwrap());
- aws_transcribe::Client::new(&config)
+ config
}
};
let mut state = self.state.lock().unwrap();
let (buffer_tx, buffer_rx) = mpsc::channel(1);
- let ws_loop_handle = RUNTIME.spawn(self.build_ws_loop_fut(
- client,
+
+ let ws_loop_ctx = TranscriberLoop::new(
+ self,
+ &aws_config,
transcription_settings,
lateness,
buffer_rx,
transcript_notif_tx,
- ));
+ );
+ let ws_loop_handle = RUNTIME.spawn(ws_loop_ctx.run());
+ state.aws_config = Some(aws_config);
state.ws_loop_handle = Some(ws_loop_handle);
state.buffer_tx = Some(buffer_tx);
Ok(())
}
- fn build_ws_loop_fut(
- &self,
- client: aws_transcribe::Client,
- settings: TranscriptionSettings,
- lateness: gst::ClockTime,
- buffer_rx: mpsc::Receiver<gst::Buffer>,
- transcript_notif_tx: mpsc::Sender<()>,
- ) -> impl Future<Output = Result<(), gst::ErrorMessage>> {
- let imp_weak = self.downgrade();
- async move {
- use gst::glib::subclass::ObjectImplWeakRef;
-
- // Guard that restores client & transcript_notif_tx when the ws loop is done
- struct Guard {
- imp_weak: ObjectImplWeakRef<Transcriber>,
- client: Option<aws_transcribe::Client>,
- transcript_notif_tx: Option<mpsc::Sender<()>>,
- }
-
- impl Guard {
- fn client(&self) -> &aws_transcribe::Client {
- self.client.as_ref().unwrap()
- }
-
- fn transcript_notif_tx(&mut self) -> &mut mpsc::Sender<()> {
- self.transcript_notif_tx.as_mut().unwrap()
- }
- }
-
- impl Drop for Guard {
- fn drop(&mut self) {
- if let Some(imp) = self.imp_weak.upgrade() {
- let mut state = imp.state.lock().unwrap();
- state.client = self.client.take();
- state.transcript_notif_tx = self.transcript_notif_tx.take();
- }
- }
- }
-
- let mut guard = Guard {
- imp_weak: imp_weak.clone(),
- client: Some(client),
- transcript_notif_tx: Some(transcript_notif_tx),
- };
-
- // Stream the incoming buffers chunked
- let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| {
- async_stream::stream! {
- let data = buffer.map_readable().unwrap();
- use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
- for chunk in data.chunks(8192) {
- yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build()));
- }
- }
- });
-
- let mut transcribe_builder = guard
- .client()
- .start_stream_transcription()
- .language_code(settings.lang_code)
- .media_sample_rate_hertz(settings.sample_rate)
- .media_encoding(model::MediaEncoding::Pcm)
- .enable_partial_results_stabilization(true)
- .partial_results_stability(settings.results_stability)
- .set_vocabulary_name(settings.vocabulary)
- .set_session_id(settings.session_id);
-
- if let Some(vocabulary_filter) = settings.vocabulary_filter {
- transcribe_builder = transcribe_builder
- .vocabulary_filter_name(vocabulary_filter)
- .vocabulary_filter_method(settings.vocabulary_filter_method);
- }
-
- let mut output = transcribe_builder
- .audio_stream(chunk_stream.into())
- .send()
- .await
- .map_err(|err| {
- let err = format!("Transcribe ws init error: {err}");
- if let Some(imp) = imp_weak.upgrade() {
- gst::error!(CAT, imp: imp, "{err}");
- }
- gst::error_msg!(gst::LibraryError::Init, ["{err}"])
- })?;
-
- while let Some(event) = output
- .transcript_result_stream
- .recv()
- .await
- .map_err(|err| {
- let err = format!("Transcribe ws stream error: {err}");
- if let Some(imp) = imp_weak.upgrade() {
- gst::error!(CAT, imp: imp, "{err}");
- }
- gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
- })?
- {
- if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
- let mut enqueued = false;
-
- if let Some(result) = transcript_evt
- .transcript
- .and_then(|transcript| transcript.results)
- .and_then(|mut results| results.drain(..).next())
- {
- let Some(imp) = imp_weak.upgrade() else { break };
-
- gst::trace!(CAT, imp: imp, "Received: {result:?}");
-
- if let Some(alternative) = result
- .alternatives
- .and_then(|mut alternatives| alternatives.drain(..).next())
- {
- if let Some(items) = alternative.items {
- enqueued = imp.enqueue(items, result.is_partial, lateness);
- }
- }
- }
-
- if enqueued && guard.transcript_notif_tx().send(()).await.is_err() {
- if let Some(imp) = imp_weak.upgrade() {
- gst::debug!(CAT, imp: imp, "Terminated transcript_notif_tx channel");
- }
- break;
- }
- } else if let Some(imp) = imp_weak.upgrade() {
- gst::warning!(
- CAT,
- imp: imp,
- "Transcribe ws returned unknown event: consider upgrading the SDK"
- )
- } else {
- // imp has left the building
- break;
- }
- }
-
- if let Some(imp) = imp_weak.upgrade() {
- gst::debug!(CAT, imp: imp, "Exiting ws loop");
- }
-
- Ok(())
- }
- }
-
fn disconnect(&self) {
let mut state = self.state.lock().unwrap();
gst::info!(CAT, imp: self, "Unpreparing");