diff options
author | TharinduDR <rhtdranasinghe@gmail.com> | 2021-02-11 15:26:41 +0300 |
---|---|---|
committer | TharinduDR <rhtdranasinghe@gmail.com> | 2021-02-11 15:26:41 +0300 |
commit | b1bc1a948f10b62d9a811275bd4430382e10964f (patch) | |
tree | c953d4baf93c52ccfbe33027d30368c3d467a30e | |
parent | 4e571fafd9dd2171647ba0b6bf0d5c490b33d482 (diff) |
055: Adding word level examples
-rw-r--r-- | transquest/app/monotransquest_app.py | 9 | ||||
-rw-r--r-- | transquest/app/util/model_downloader.py | 12 |
2 files changed, 12 insertions, 9 deletions
diff --git a/transquest/app/monotransquest_app.py b/transquest/app/monotransquest_app.py index 7878212..3a40e4c 100644 --- a/transquest/app/monotransquest_app.py +++ b/transquest/app/monotransquest_app.py @@ -15,13 +15,14 @@ class MonoTransQuestApp: self.use_cuda = use_cuda self.cuda_device = cuda_device + MODEL_CONFIG = { - "monotransquest-da-si_en": ("xlmroberta", "1-UXvna_RGnb6_TTRr4vSGCqA5yl0SYn9"), - "monotransquest-da-ro_en": ("xlmroberta", "1-aeDbR_ftqsTslFJbNybebj5MAhPfIw8") + "monotransquest-da-si_en": ("xlmroberta", "1-UXvna_RGnb6_TTRr4vSGCqA5yl0SYn9", 3.8), + "monotransquest-da-ro_en": ("xlmroberta", "1-aeDbR_ftqsTslFJbNybebj5MAhPfIw8", 3.8) } if model_name_or_path in MODEL_CONFIG: - self.trained_model_type, self.drive_id = MODEL_CONFIG[model_name_or_path] + self.trained_model_type, self.drive_id, self.size = MODEL_CONFIG[model_name_or_path] try: from torch.hub import _get_torch_home @@ -38,7 +39,7 @@ class MonoTransQuestApp: gdd.download_file_from_google_drive(file_id=self.drive_id, dest_path=os.path.join(self.model_path, "model.zip"), - showsize=True, unzip=True, overwrite=True) + showsize=True, unzip=True, overwrite=True, size=self.size) self.model = MonoTransQuestModel(self.trained_model_type, self.model_path, use_cuda=self.use_cuda, cuda_device=self.cuda_device) diff --git a/transquest/app/util/model_downloader.py b/transquest/app/util/model_downloader.py index 1b89faf..e064a26 100644 --- a/transquest/app/util/model_downloader.py +++ b/transquest/app/util/model_downloader.py @@ -17,12 +17,12 @@ class GoogleDriveDownloader: """ Minimal class to download shared files from Google Drive. """ - + MODEL_SIZE = 3.8 CHUNK_SIZE = 32768 DOWNLOAD_URL = 'https://docs.google.com/uc?export=download' @staticmethod - def download_file_from_google_drive(file_id, dest_path, overwrite=False, unzip=False, showsize=False): + def download_file_from_google_drive(file_id, dest_path, overwrite=False, unzip=False, showsize=False, size=MODEL_SIZE): """ Downloads a shared file from google drive into a given folder. Optionally unzips it. @@ -42,6 +42,8 @@ class GoogleDriveDownloader: If the file is not a zip file, ignores it. showsize: bool optional, if True print the current download size. + size:float + optional, if given it shows the progress of the download Returns ------- None @@ -69,7 +71,7 @@ class GoogleDriveDownloader: logger.info("\n") # Skip to the next line current_download_size = [0] - GoogleDriveDownloader._save_response_content(response, dest_path, showsize, current_download_size) + GoogleDriveDownloader._save_response_content(response, dest_path, showsize, current_download_size, size) logger.info('Done.') if unzip: @@ -90,13 +92,13 @@ class GoogleDriveDownloader: return None @staticmethod - def _save_response_content(response, destination, showsize, current_size): + def _save_response_content(response, destination, showsize, current_size, total_size): with open(destination, 'wb') as f: for chunk in response.iter_content(GoogleDriveDownloader.CHUNK_SIZE): if chunk: # filter out keep-alive new chunks f.write(chunk) if showsize: - print('\r' + GoogleDriveDownloader.sizeof_fmt(current_size[0]), end=' ') + print('\r' + GoogleDriveDownloader.sizeof_fmt(current_size[0]/total_size), end=' ') stdout.flush() current_size[0] += GoogleDriveDownloader.CHUNK_SIZE |