Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/TharinduDR/TransQuest.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTharinduDR <rhtdranasinghe@gmail.com>2021-02-11 15:26:41 +0300
committerTharinduDR <rhtdranasinghe@gmail.com>2021-02-11 15:26:41 +0300
commitb1bc1a948f10b62d9a811275bd4430382e10964f (patch)
treec953d4baf93c52ccfbe33027d30368c3d467a30e
parent4e571fafd9dd2171647ba0b6bf0d5c490b33d482 (diff)
055: Adding word level examples
-rw-r--r--transquest/app/monotransquest_app.py9
-rw-r--r--transquest/app/util/model_downloader.py12
2 files changed, 12 insertions, 9 deletions
diff --git a/transquest/app/monotransquest_app.py b/transquest/app/monotransquest_app.py
index 7878212..3a40e4c 100644
--- a/transquest/app/monotransquest_app.py
+++ b/transquest/app/monotransquest_app.py
@@ -15,13 +15,14 @@ class MonoTransQuestApp:
self.use_cuda = use_cuda
self.cuda_device = cuda_device
+
MODEL_CONFIG = {
- "monotransquest-da-si_en": ("xlmroberta", "1-UXvna_RGnb6_TTRr4vSGCqA5yl0SYn9"),
- "monotransquest-da-ro_en": ("xlmroberta", "1-aeDbR_ftqsTslFJbNybebj5MAhPfIw8")
+ "monotransquest-da-si_en": ("xlmroberta", "1-UXvna_RGnb6_TTRr4vSGCqA5yl0SYn9", 3.8),
+ "monotransquest-da-ro_en": ("xlmroberta", "1-aeDbR_ftqsTslFJbNybebj5MAhPfIw8", 3.8)
}
if model_name_or_path in MODEL_CONFIG:
- self.trained_model_type, self.drive_id = MODEL_CONFIG[model_name_or_path]
+ self.trained_model_type, self.drive_id, self.size = MODEL_CONFIG[model_name_or_path]
try:
from torch.hub import _get_torch_home
@@ -38,7 +39,7 @@ class MonoTransQuestApp:
gdd.download_file_from_google_drive(file_id=self.drive_id,
dest_path=os.path.join(self.model_path, "model.zip"),
- showsize=True, unzip=True, overwrite=True)
+ showsize=True, unzip=True, overwrite=True, size=self.size)
self.model = MonoTransQuestModel(self.trained_model_type, self.model_path, use_cuda=self.use_cuda,
cuda_device=self.cuda_device)
diff --git a/transquest/app/util/model_downloader.py b/transquest/app/util/model_downloader.py
index 1b89faf..e064a26 100644
--- a/transquest/app/util/model_downloader.py
+++ b/transquest/app/util/model_downloader.py
@@ -17,12 +17,12 @@ class GoogleDriveDownloader:
"""
Minimal class to download shared files from Google Drive.
"""
-
+ MODEL_SIZE = 3.8
CHUNK_SIZE = 32768
DOWNLOAD_URL = 'https://docs.google.com/uc?export=download'
@staticmethod
- def download_file_from_google_drive(file_id, dest_path, overwrite=False, unzip=False, showsize=False):
+ def download_file_from_google_drive(file_id, dest_path, overwrite=False, unzip=False, showsize=False, size=MODEL_SIZE):
"""
Downloads a shared file from google drive into a given folder.
Optionally unzips it.
@@ -42,6 +42,8 @@ class GoogleDriveDownloader:
If the file is not a zip file, ignores it.
showsize: bool
optional, if True print the current download size.
+ size:float
+ optional, if given it shows the progress of the download
Returns
-------
None
@@ -69,7 +71,7 @@ class GoogleDriveDownloader:
logger.info("\n") # Skip to the next line
current_download_size = [0]
- GoogleDriveDownloader._save_response_content(response, dest_path, showsize, current_download_size)
+ GoogleDriveDownloader._save_response_content(response, dest_path, showsize, current_download_size, size)
logger.info('Done.')
if unzip:
@@ -90,13 +92,13 @@ class GoogleDriveDownloader:
return None
@staticmethod
- def _save_response_content(response, destination, showsize, current_size):
+ def _save_response_content(response, destination, showsize, current_size, total_size):
with open(destination, 'wb') as f:
for chunk in response.iter_content(GoogleDriveDownloader.CHUNK_SIZE):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
if showsize:
- print('\r' + GoogleDriveDownloader.sizeof_fmt(current_size[0]), end=' ')
+ print('\r' + GoogleDriveDownloader.sizeof_fmt(current_size[0]/total_size), end=' ')
stdout.flush()
current_size[0] += GoogleDriveDownloader.CHUNK_SIZE