diff options
author | Jelle Spijker <spijker.jelle@gmail.com> | 2021-11-22 15:13:40 +0300 |
---|---|---|
committer | Jelle Spijker <spijker.jelle@gmail.com> | 2021-11-22 15:13:40 +0300 |
commit | af56acb820dae631bf6ea9f42e5b3768b766cdce (patch) | |
tree | c0615e55a7acbf0d9f01579d75c4b7f266740d4e | |
parent | 5511fecf2fcd58ce501f235203930b2a36f4d801 (diff) | |
parent | 9dd251975d2fbf339021976cf9f202c8a545c13c (diff) |
Merge branch 'CURA-8539_oauth_via_httprequestmanager' into conan/CURA-8539_oauth_via_httprequestmanagerconan/CURA-8539_oauth_via_httprequestmanager
-rw-r--r-- | cura/API/Account.py | 43 | ||||
-rw-r--r-- | cura/OAuth2/AuthorizationHelpers.py | 175 | ||||
-rw-r--r-- | cura/OAuth2/AuthorizationRequestHandler.py | 19 | ||||
-rw-r--r-- | cura/OAuth2/AuthorizationService.py | 168 | ||||
-rw-r--r-- | tests/TestOAuth2.py | 8 |
5 files changed, 234 insertions, 179 deletions
diff --git a/cura/API/Account.py b/cura/API/Account.py index 2d4b204333..bf8a883c1a 100644 --- a/cura/API/Account.py +++ b/cura/API/Account.py @@ -1,15 +1,15 @@ -# Copyright (c) 2018 Ultimaker B.V. +# Copyright (c) 2021 Ultimaker B.V. # Cura is released under the terms of the LGPLv3 or higher. -from datetime import datetime -from typing import Any, Optional, Dict, TYPE_CHECKING, Callable +from datetime import datetime from PyQt5.QtCore import QObject, pyqtSignal, pyqtSlot, pyqtProperty, QTimer, Q_ENUMS +from typing import Any, Optional, Dict, TYPE_CHECKING, Callable from UM.Logger import Logger from UM.Message import Message from UM.i18n import i18nCatalog from cura.OAuth2.AuthorizationService import AuthorizationService -from cura.OAuth2.Models import OAuth2Settings +from cura.OAuth2.Models import OAuth2Settings, UserProfile from cura.UltimakerCloud import UltimakerCloudConstants if TYPE_CHECKING: @@ -46,6 +46,9 @@ class Account(QObject): loginStateChanged = pyqtSignal(bool) """Signal emitted when user logged in or out""" + userProfileChanged = pyqtSignal() + """Signal emitted when new account information is available.""" + additionalRightsChanged = pyqtSignal("QVariantMap") """Signal emitted when a users additional rights change""" @@ -73,6 +76,7 @@ class Account(QObject): self._error_message = None # type: Optional[Message] self._logged_in = False + self._user_profile = None self._additional_rights: Dict[str, Any] = {} self._sync_state = SyncState.IDLE self._manual_sync_enabled = False @@ -196,12 +200,17 @@ class Account(QObject): self._logged_in = logged_in self.loginStateChanged.emit(logged_in) if logged_in: + self._authorization_service.getUserProfile(self._onProfileChanged) self._setManualSyncEnabled(False) self._sync() else: if self._update_timer.isActive(): self._update_timer.stop() + def _onProfileChanged(self, profile: UserProfile): + self._user_profile = profile + self.userProfileChanged.emit() + def _sync(self) -> None: """Signals all sync services to start syncing @@ -243,32 +252,28 @@ class Account(QObject): return self._authorization_service.startAuthorizationFlow(force_logout_before_login) - @pyqtProperty(str, notify=loginStateChanged) + @pyqtProperty(str, notify = userProfileChanged) def userName(self): - user_profile = self._authorization_service.getUserProfile() - if not user_profile: - return None - return user_profile.username + if not self._user_profile: + return "" + return self._user_profile.username - @pyqtProperty(str, notify = loginStateChanged) + @pyqtProperty(str, notify = userProfileChanged) def profileImageUrl(self): - user_profile = self._authorization_service.getUserProfile() - if not user_profile: - return None - return user_profile.profile_image_url + if not self._user_profile: + return "" + return self._user_profile.profile_image_url @pyqtProperty(str, notify=accessTokenChanged) def accessToken(self) -> Optional[str]: return self._authorization_service.getAccessToken() - @pyqtProperty("QVariantMap", notify = loginStateChanged) + @pyqtProperty("QVariantMap", notify = userProfileChanged) def userProfile(self) -> Optional[Dict[str, Optional[str]]]: """None if no user is logged in otherwise the logged in user as a dict containing containing user_id, username and profile_image_url """ - - user_profile = self._authorization_service.getUserProfile() - if not user_profile: + if not self._user_profile: return None - return user_profile.__dict__ + return self._user_profile.__dict__ @pyqtProperty(str, notify=lastSyncDateTimeChanged) def lastSyncDateTime(self) -> str: diff --git a/cura/OAuth2/AuthorizationHelpers.py b/cura/OAuth2/AuthorizationHelpers.py index d6f4980fe4..d84da46c5f 100644 --- a/cura/OAuth2/AuthorizationHelpers.py +++ b/cura/OAuth2/AuthorizationHelpers.py @@ -1,18 +1,19 @@ # Copyright (c) 2021 Ultimaker B.V. # Cura is released under the terms of the LGPLv3 or higher. +from base64 import b64encode from datetime import datetime -import json -import secrets from hashlib import sha512 -from base64 import b64encode -from typing import Optional -import requests +from PyQt5.QtNetwork import QNetworkReply +import secrets +from typing import Callable, Optional +import urllib.parse +from cura.OAuth2.Models import AuthenticationResponse, UserProfile, OAuth2Settings from UM.i18n import i18nCatalog from UM.Logger import Logger +from UM.TaskManagement.HttpRequestManager import HttpRequestManager # To download log-in tokens. -from cura.OAuth2.Models import AuthenticationResponse, UserProfile, OAuth2Settings catalog = i18nCatalog("cura") TOKEN_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S" @@ -30,14 +31,13 @@ class AuthorizationHelpers: return self._settings - def getAccessTokenUsingAuthorizationCode(self, authorization_code: str, verification_code: str) -> "AuthenticationResponse": - """Request the access token from the authorization server. - + def getAccessTokenUsingAuthorizationCode(self, authorization_code: str, verification_code: str, callback: Callable[[AuthenticationResponse], None]) -> None: + """ + Request the access token from the authorization server. :param authorization_code: The authorization code from the 1st step. :param verification_code: The verification code needed for the PKCE extension. - :return: An AuthenticationResponse object. + :param callback: Once the token has been obtained, this function will be called with the response. """ - data = { "client_id": self._settings.CLIENT_ID if self._settings.CLIENT_ID is not None else "", "redirect_uri": self._settings.CALLBACK_URL if self._settings.CALLBACK_URL is not None else "", @@ -46,18 +46,20 @@ class AuthorizationHelpers: "code_verifier": verification_code, "scope": self._settings.CLIENT_SCOPES if self._settings.CLIENT_SCOPES is not None else "", } - try: - return self.parseTokenResponse(requests.post(self._token_url, data = data)) # type: ignore - except requests.exceptions.ConnectionError as connection_error: - return AuthenticationResponse(success = False, err_message = f"Unable to connect to remote server: {connection_error}") - - def getAccessTokenUsingRefreshToken(self, refresh_token: str) -> "AuthenticationResponse": - """Request the access token from the authorization server using a refresh token. + headers = {"Content-type": "application/x-www-form-urlencoded"} + HttpRequestManager.getInstance().post( + self._token_url, + data = urllib.parse.urlencode(data).encode("UTF-8"), + headers_dict = headers, + callback = lambda response: self.parseTokenResponse(response, callback) + ) - :param refresh_token: - :return: An AuthenticationResponse object. + def getAccessTokenUsingRefreshToken(self, refresh_token: str, callback: Callable[[AuthenticationResponse], None]) -> None: + """ + Request the access token from the authorization server using a refresh token. + :param refresh_token: A long-lived token used to refresh the authentication token. + :param callback: Once the token has been obtained, this function will be called with the response. """ - Logger.log("d", "Refreshing the access token for [%s]", self._settings.OAUTH_SERVER_URL) data = { "client_id": self._settings.CLIENT_ID if self._settings.CLIENT_ID is not None else "", @@ -66,75 +68,94 @@ class AuthorizationHelpers: "refresh_token": refresh_token, "scope": self._settings.CLIENT_SCOPES if self._settings.CLIENT_SCOPES is not None else "", } - try: - return self.parseTokenResponse(requests.post(self._token_url, data = data)) # type: ignore - except requests.exceptions.ConnectionError: - return AuthenticationResponse(success = False, err_message = "Unable to connect to remote server") - except OSError as e: - return AuthenticationResponse(success = False, err_message = "Operating system is unable to set up a secure connection: {err}".format(err = str(e))) + headers = {"Content-type": "application/x-www-form-urlencoded"} + HttpRequestManager.getInstance().post( + self._token_url, + data = urllib.parse.urlencode(data).encode("UTF-8"), + headers_dict = headers, + callback = lambda response: self.parseTokenResponse(response, callback) + ) - @staticmethod - def parseTokenResponse(token_response: requests.models.Response) -> "AuthenticationResponse": + def parseTokenResponse(self, token_response: QNetworkReply, callback: Callable[[AuthenticationResponse], None]) -> None: """Parse the token response from the authorization server into an AuthenticationResponse object. :param token_response: The JSON string data response from the authorization server. :return: An AuthenticationResponse object. """ - - token_data = None - - try: - token_data = json.loads(token_response.text) - except ValueError: - Logger.log("w", "Could not parse token response data: %s", token_response.text) - + token_data = HttpRequestManager.readJSON(token_response) if not token_data: - return AuthenticationResponse(success = False, err_message = catalog.i18nc("@message", "Could not read response.")) - - if token_response.status_code not in (200, 201): - return AuthenticationResponse(success = False, err_message = token_data["error_description"]) - - return AuthenticationResponse(success=True, - token_type=token_data["token_type"], - access_token=token_data["access_token"], - refresh_token=token_data["refresh_token"], - expires_in=token_data["expires_in"], - scope=token_data["scope"], - received_at=datetime.now().strftime(TOKEN_TIMESTAMP_FORMAT)) - - def parseJWT(self, access_token: str) -> Optional["UserProfile"]: + callback(AuthenticationResponse(success = False, err_message = catalog.i18nc("@message", "Could not read response."))) + return + + if token_response.error() != QNetworkReply.NetworkError.NoError: + callback(AuthenticationResponse(success = False, err_message = token_data["error_description"])) + return + + callback(AuthenticationResponse(success = True, + token_type = token_data["token_type"], + access_token = token_data["access_token"], + refresh_token = token_data["refresh_token"], + expires_in = token_data["expires_in"], + scope = token_data["scope"], + received_at = datetime.now().strftime(TOKEN_TIMESTAMP_FORMAT))) + return + + def checkToken(self, access_token: str, success_callback: Optional[Callable[[UserProfile], None]] = None, failed_callback: Optional[Callable[[], None]] = None) -> None: """Calls the authentication API endpoint to get the token data. + The API is called asynchronously. When a response is given, the callback is called with the user's profile. :param access_token: The encoded JWT token. - :return: Dict containing some profile data. + :param success_callback: When a response is given, this function will be called with a user profile. If None, + there will not be a callback. + :param failed_callback: When the request failed or the response didn't parse, this function will be called. """ - - try: - check_token_url = "{}/check-token".format(self._settings.OAUTH_SERVER_URL) - Logger.log("d", "Checking the access token for [%s]", check_token_url) - token_request = requests.get(check_token_url, headers = { - "Authorization": "Bearer {}".format(access_token) - }) - except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): - # Connection was suddenly dropped. Nothing we can do about that. - Logger.logException("w", "Something failed while attempting to parse the JWT token") - return None - if token_request.status_code not in (200, 201): - Logger.log("w", "Could not retrieve token data from auth server: %s", token_request.text) - return None - user_data = token_request.json().get("data") - if not user_data or not isinstance(user_data, dict): - Logger.log("w", "Could not parse user data from token: %s", user_data) - return None - - return UserProfile( - user_id = user_data["user_id"], - username = user_data["username"], - profile_image_url = user_data.get("profile_image_url", ""), - organization_id = user_data.get("organization", {}).get("organization_id"), - subscriptions = user_data.get("subscriptions", []) + check_token_url = "{}/check-token".format(self._settings.OAUTH_SERVER_URL) + Logger.log("d", "Checking the access token for [%s]", check_token_url) + headers = { + "Authorization": f"Bearer {access_token}" + } + HttpRequestManager.getInstance().get( + check_token_url, + headers_dict = headers, + callback = lambda reply: self._parseUserProfile(reply, success_callback, failed_callback), + error_callback = lambda _, _2: failed_callback() if failed_callback is not None else None ) + def _parseUserProfile(self, reply: QNetworkReply, success_callback: Optional[Callable[[UserProfile], None]], failed_callback: Optional[Callable[[], None]] = None) -> None: + """ + Parses the user profile from a reply to /check-token. + + If the response is valid, the callback will be called to return the user profile to the caller. + :param reply: A network reply to a request to the /check-token URL. + :param success_callback: A function to call once a user profile was successfully obtained. + :param failed_callback: A function to call if parsing the profile failed. + """ + if reply.error() != QNetworkReply.NetworkError.NoError: + Logger.warning(f"Could not access account information. QNetworkError {reply.errorString()}") + failed_callback() + return + + profile_data = HttpRequestManager.getInstance().readJSON(reply) + if profile_data is None or "data" not in profile_data: + Logger.warning("Could not parse user data from token.") + failed_callback() + return + profile_data = profile_data["data"] + + required_fields = {"user_id", "username"} + if "user_id" not in profile_data or "username" not in profile_data: + Logger.warning(f"User data missing required field(s): {required_fields - set(profile_data.keys())}") + failed_callback() + return + + success_callback(UserProfile( + user_id = profile_data["user_id"], + username = profile_data["username"], + profile_image_url = profile_data.get("profile_image_url", ""), + organization_id = profile_data.get("organization", {}).get("organization_id"), + subscriptions = profile_data.get("subscriptions", []) + )) + @staticmethod def generateVerificationCode(code_length: int = 32) -> str: """Generate a verification code of arbitrary length. diff --git a/cura/OAuth2/AuthorizationRequestHandler.py b/cura/OAuth2/AuthorizationRequestHandler.py index c7ce9b6faf..ff01969c50 100644 --- a/cura/OAuth2/AuthorizationRequestHandler.py +++ b/cura/OAuth2/AuthorizationRequestHandler.py @@ -2,6 +2,7 @@ # Cura is released under the terms of the LGPLv3 or higher. from http.server import BaseHTTPRequestHandler +from threading import Lock # To turn an asynchronous call synchronous. from typing import Optional, Callable, Tuple, Dict, Any, List, TYPE_CHECKING from urllib.parse import parse_qs, urlparse @@ -70,13 +71,23 @@ class AuthorizationRequestHandler(BaseHTTPRequestHandler): if state != self.state: token_response = AuthenticationResponse( success = False, - err_message=catalog.i18nc("@message", - "The provided state is not correct.") + err_message = catalog.i18nc("@message", "The provided state is not correct.") ) elif code and self.authorization_helpers is not None and self.verification_code is not None: + token_response = AuthenticationResponse( + success = False, + err_message = catalog.i18nc("@message", "Timeout when authenticating with the account server.") + ) # If the code was returned we get the access token. - token_response = self.authorization_helpers.getAccessTokenUsingAuthorizationCode( - code, self.verification_code) + lock = Lock() + lock.acquire() + + def callback(response: AuthenticationResponse) -> None: + nonlocal token_response + token_response = response + lock.release() + self.authorization_helpers.getAccessTokenUsingAuthorizationCode(code, self.verification_code, callback) + lock.acquire(timeout = 60) # Block thread until request is completed (which releases the lock). If not acquired, the timeout message stays. elif self._queryGet(query, "error_code") == "user_denied": # Otherwise we show an error message (probably the user clicked "Deny" in the auth dialog). diff --git a/cura/OAuth2/AuthorizationService.py b/cura/OAuth2/AuthorizationService.py index 291845fd78..fb2ba40c71 100644 --- a/cura/OAuth2/AuthorizationService.py +++ b/cura/OAuth2/AuthorizationService.py @@ -3,10 +3,9 @@ import json from datetime import datetime, timedelta -from typing import Optional, TYPE_CHECKING, Dict +from typing import Callable, Dict, Optional, TYPE_CHECKING from urllib.parse import urlencode, quote_plus -import requests.exceptions from PyQt5.QtCore import QUrl from PyQt5.QtGui import QDesktopServices @@ -62,69 +61,80 @@ class AuthorizationService: if self._preferences: self._preferences.addPreference(self._settings.AUTH_DATA_PREFERENCE_KEY, "{}") - def getUserProfile(self) -> Optional["UserProfile"]: - """Get the user profile as obtained from the JWT (JSON Web Token). - - If the JWT is not yet parsed, calling this will take care of that. + def getUserProfile(self, callback: Callable[[Optional["UserProfile"]], None] = None) -> None: + """ + Get the user profile as obtained from the JWT (JSON Web Token). - :return: UserProfile if a user is logged in, None otherwise. + If the JWT is not yet checked and parsed, calling this will take care of that. + :param callback: Once the user profile is obtained, this function will be called with the given user profile. If + the profile fails to be obtained, this function will be called with None. See also: :py:method:`cura.OAuth2.AuthorizationService.AuthorizationService._parseJWT` """ + if self._user_profile: + # We already obtained the profile. No need to make another request for it. + if callback is not None: + callback(self._user_profile) + return - if not self._user_profile: - # If no user profile was stored locally, we try to get it from JWT. - try: - self._user_profile = self._parseJWT() - except requests.exceptions.ConnectionError: - # Unable to get connection, can't login. - Logger.logException("w", "Unable to validate user data with the remote server.") - return None - - if not self._user_profile and self._auth_data: - # If there is still no user profile from the JWT, we have to log in again. - Logger.log("w", "The user profile could not be loaded. The user must log in again!") - self.deleteAuthData() - return None - - return self._user_profile - - def _parseJWT(self) -> Optional["UserProfile"]: - """Tries to parse the JWT (JSON Web Token) data, which it does if all the needed data is there. - - :return: UserProfile if it was able to parse, None otherwise. + # If no user profile was stored locally, we try to get it from JWT. + def store_profile(profile: Optional["UserProfile"]): + if profile is not None: + self._user_profile = profile + if callback is not None: + callback(profile) + elif self._auth_data: + # If there is no user profile from the JWT, we have to log in again. + Logger.warning("The user profile could not be loaded. The user must log in again!") + self.deleteAuthData() + if callback is not None: + callback(None) + else: + if callback is not None: + callback(None) + + self._parseJWT(callback = store_profile) + + def _parseJWT(self, callback: Callable[[Optional["UserProfile"]], None]) -> None: + """ + Tries to parse the JWT (JSON Web Token) data, which it does if all the needed data is there. + :param callback: A function to call asynchronously once the user profile has been obtained. It will be called + with `None` if it failed to obtain a user profile. """ if not self._auth_data or self._auth_data.access_token is None: # If no auth data exists, we should always log in again. - Logger.log("d", "There was no auth data or access token") - return None - - try: - user_data = self._auth_helpers.parseJWT(self._auth_data.access_token) - except AttributeError: - # THis might seem a bit double, but we get crash reports about this (CURA-2N2 in sentry) - Logger.log("d", "There was no auth data or access token") - return None + Logger.debug("There was no auth data or access token") + callback(None) + return - if user_data: - # If the profile was found, we return it immediately. - return user_data - # The JWT was expired or invalid and we should request a new one. - if self._auth_data.refresh_token is None: - Logger.log("w", "There was no refresh token in the auth data.") - return None - self._auth_data = self._auth_helpers.getAccessTokenUsingRefreshToken(self._auth_data.refresh_token) - if not self._auth_data or self._auth_data.access_token is None: - Logger.log("w", "Unable to use the refresh token to get a new access token.") - # The token could not be refreshed using the refresh token. We should login again. - return None - # Ensure it gets stored as otherwise we only have it in memory. The stored refresh token has been deleted - # from the server already. Do not store the auth_data if we could not get new auth_data (eg due to a - # network error), since this would cause an infinite loop trying to get new auth-data - if self._auth_data.success: - self._storeAuthData(self._auth_data) - return self._auth_helpers.parseJWT(self._auth_data.access_token) + # When we checked the token we may get a user profile. This callback checks if that is a valid one and tries to refresh the token if it's not. + def check_user_profile(user_profile): + if user_profile: + # If the profile was found, we call it back immediately. + callback(user_profile) + return + # The JWT was expired or invalid and we should request a new one. + if self._auth_data.refresh_token is None: + Logger.warning("There was no refresh token in the auth data.") + callback(None) + return + + def process_auth_data(auth_data: AuthenticationResponse): + if auth_data.access_token is None: + Logger.warning("Unable to use the refresh token to get a new access token.") + callback(None) + return + # Ensure it gets stored as otherwise we only have it in memory. The stored refresh token has been + # deleted from the server already. Do not store the auth_data if we could not get new auth_data (e.g. + # due to a network error), since this would cause an infinite loop trying to get new auth-data. + if auth_data.success: + self._storeAuthData(auth_data) + self._auth_helpers.checkToken(auth_data.access_token, callback, lambda: callback(None)) + + self._auth_helpers.getAccessTokenUsingRefreshToken(self._auth_data.refresh_token, process_auth_data) + + self._auth_helpers.checkToken(self._auth_data.access_token, check_user_profile) def getAccessToken(self) -> Optional[str]: """Get the access token as provided by the response data.""" @@ -149,13 +159,16 @@ class AuthorizationService: if self._auth_data is None or self._auth_data.refresh_token is None: Logger.log("w", "Unable to refresh access token, since there is no refresh token.") return - response = self._auth_helpers.getAccessTokenUsingRefreshToken(self._auth_data.refresh_token) - if response.success: - self._storeAuthData(response) - self.onAuthStateChanged.emit(logged_in = True) - else: - Logger.log("w", "Failed to get a new access token from the server.") - self.onAuthStateChanged.emit(logged_in = False) + + def process_auth_data(response: AuthenticationResponse): + if response.success: + self._storeAuthData(response) + self.onAuthStateChanged.emit(logged_in = True) + else: + Logger.warning("Failed to get a new access token from the server.") + self.onAuthStateChanged.emit(logged_in = False) + + self._auth_helpers.getAccessTokenUsingRefreshToken(self._auth_data.refresh_token, process_auth_data) def deleteAuthData(self) -> None: """Delete the authentication data that we have stored locally (eg; logout)""" @@ -244,21 +257,22 @@ class AuthorizationService: preferences_data = json.loads(self._preferences.getValue(self._settings.AUTH_DATA_PREFERENCE_KEY)) if preferences_data: self._auth_data = AuthenticationResponse(**preferences_data) + # Also check if we can actually get the user profile information. - user_profile = self.getUserProfile() - if user_profile is not None: - self.onAuthStateChanged.emit(logged_in = True) - Logger.log("d", "Auth data was successfully loaded") - else: - if self._unable_to_get_data_message is not None: - self._unable_to_get_data_message.hide() - - self._unable_to_get_data_message = Message(i18n_catalog.i18nc("@info", - "Unable to reach the Ultimaker account server."), - title = i18n_catalog.i18nc("@info:title", "Warning"), - message_type = Message.MessageType.ERROR) - Logger.log("w", "Unable to load auth data from preferences") - self._unable_to_get_data_message.show() + def callback(profile: Optional["UserProfile"]): + if profile is not None: + self.onAuthStateChanged.emit(logged_in = True) + Logger.debug("Auth data was successfully loaded") + else: + if self._unable_to_get_data_message is not None: + self._unable_to_get_data_message.show() + else: + self._unable_to_get_data_message = Message(i18n_catalog.i18nc("@info", "Unable to reach the Ultimaker account server."), + title = i18n_catalog.i18nc("@info:title", "Log-in failed"), + message_type = Message.MessageType.ERROR) + Logger.warning("Unable to get user profile using auth data from preferences.") + self._unable_to_get_data_message.show() + self.getUserProfile(callback) except (ValueError, TypeError): Logger.logException("w", "Could not load auth data from preferences") @@ -272,7 +286,7 @@ class AuthorizationService: self._auth_data = auth_data if auth_data: - self._user_profile = self.getUserProfile() + self.getUserProfile() self._preferences.setValue(self._settings.AUTH_DATA_PREFERENCE_KEY, json.dumps(auth_data.dump())) else: Logger.log("d", "Clearing the user profile") diff --git a/tests/TestOAuth2.py b/tests/TestOAuth2.py index 2c039b296a..24cfe50921 100644 --- a/tests/TestOAuth2.py +++ b/tests/TestOAuth2.py @@ -1,5 +1,5 @@ from datetime import datetime -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import requests @@ -53,7 +53,11 @@ def test_cleanAuthService() -> None: # Ensure that when setting up an AuthorizationService, no data is set. authorization_service = AuthorizationService(OAUTH_SETTINGS, Preferences()) authorization_service.initialize() - assert authorization_service.getUserProfile() is None + + mock_callback = Mock() + authorization_service.getUserProfile(mock_callback) + mock_callback.assert_called_once_with(None) + assert authorization_service.getAccessToken() is None |