diff options
author | Adrien Ferrand <adferrand@users.noreply.github.com> | 2022-01-13 03:36:51 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-01-13 03:36:51 +0300 |
commit | 16aad35d31a887dab157f9d4f5e0fe9218d06064 (patch) | |
tree | 067093f1e0523b9de843afad8aca718a3570a172 /certbot-nginx | |
parent | 30b066f08260b73fc26256b5484a180468b9d0a6 (diff) |
Fully type certbot-nginx module (#9124)
* Work in progress
* Fix type
* Work in progress
* Work in progress
* Work in progress
* Work in progress
* Work in progress
* Oups.
* Fix typing in UnspacedList
* Fix logic
* Finish typing
* List certbot-nginx as fully typed in tox
* Fix lint
* Fix checks
* Organize imports
* Fix typing for Python 3.6
* Fix checks
* Fix lint
* Update certbot-nginx/certbot_nginx/_internal/configurator.py
Co-authored-by: alexzorin <alex@zor.io>
* Update certbot-nginx/certbot_nginx/_internal/configurator.py
Co-authored-by: alexzorin <alex@zor.io>
* Fix signature of deploy_cert regarding the installer interface
* Update certbot-nginx/certbot_nginx/_internal/obj.py
Co-authored-by: alexzorin <alex@zor.io>
* Fix types
* Update certbot-nginx/certbot_nginx/_internal/parser.py
Co-authored-by: alexzorin <alex@zor.io>
* Precise type
* Precise _coerce possible inputs/outputs
* Fix type
* Update certbot-nginx/certbot_nginx/_internal/http_01.py
Co-authored-by: ohemorange <ebportnoy@gmail.com>
* Fix type
* Remove an undesirable implementation.
* Fix type
Co-authored-by: alexzorin <alex@zor.io>
Co-authored-by: ohemorange <ebportnoy@gmail.com>
Diffstat (limited to 'certbot-nginx')
-rw-r--r-- | certbot-nginx/certbot_nginx/_internal/configurator.py | 190 | ||||
-rw-r--r-- | certbot-nginx/certbot_nginx/_internal/constants.py | 6 | ||||
-rw-r--r-- | certbot-nginx/certbot_nginx/_internal/display_ops.py | 9 | ||||
-rw-r--r-- | certbot-nginx/certbot_nginx/_internal/http_01.py | 48 | ||||
-rw-r--r-- | certbot-nginx/certbot_nginx/_internal/nginxparser.py | 148 | ||||
-rw-r--r-- | certbot-nginx/certbot_nginx/_internal/obj.py | 49 | ||||
-rw-r--r-- | certbot-nginx/certbot_nginx/_internal/parser.py | 184 | ||||
-rw-r--r-- | certbot-nginx/certbot_nginx/_internal/parser_obj.py | 104 |
8 files changed, 441 insertions, 297 deletions
diff --git a/certbot-nginx/certbot_nginx/_internal/configurator.py b/certbot-nginx/certbot_nginx/_internal/configurator.py index 46ec7d57e..fb819f194 100644 --- a/certbot-nginx/certbot_nginx/_internal/configurator.py +++ b/certbot-nginx/certbot_nginx/_internal/configurator.py @@ -6,30 +6,38 @@ import socket import subprocess import tempfile import time +from typing import Any +from typing import Callable from typing import Dict +from typing import Iterable from typing import List +from typing import Mapping from typing import Optional +from typing import Sequence from typing import Set from typing import Text from typing import Tuple +from typing import Type +from typing import Union +from certbot_nginx._internal import constants +from certbot_nginx._internal import display_ops +from certbot_nginx._internal import http_01 +from certbot_nginx._internal import nginxparser +from certbot_nginx._internal import obj +from certbot_nginx._internal import parser import OpenSSL import pkg_resources from acme import challenges from acme import crypto_util as acme_crypto_util +from certbot import achallenges from certbot import crypto_util from certbot import errors from certbot import util -from certbot.display import util as display_util from certbot.compat import os +from certbot.display import util as display_util from certbot.plugins import common -from certbot_nginx._internal import constants -from certbot_nginx._internal import display_ops -from certbot_nginx._internal import http_01 -from certbot_nginx._internal import nginxparser -from certbot_nginx._internal import obj -from certbot_nginx._internal import parser NAME_RANK = 0 START_WILDCARD_RANK = 1 @@ -70,7 +78,7 @@ class NginxConfigurator(common.Configurator): SSL_DIRECTIVES = ['ssl_certificate', 'ssl_certificate_key', 'ssl_dhparam'] @classmethod - def add_parser_arguments(cls, add): + def add_parser_arguments(cls, add: Callable[..., None]) -> None: default_server_root = _determine_default_server_root() add("server-root", default=constants.CLI_DEFAULTS["server_root"], help="Nginx server root directory. (default: %s)" % default_server_root) @@ -82,11 +90,11 @@ class NginxConfigurator(common.Configurator): "to apply when reloading.") @property - def nginx_conf(self): + def nginx_conf(self) -> str: """Nginx config file path.""" return os.path.join(self.conf("server_root"), "nginx.conf") - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize an Nginx Configurator. :param tup version: version of Nginx as a tuple (1, 4, 7) @@ -125,7 +133,7 @@ class NginxConfigurator(common.Configurator): self.parser: parser.NginxParser @property - def mod_ssl_conf_src(self): + def mod_ssl_conf_src(self) -> str: """Full absolute path to SSL configuration file source.""" # Why all this complexity? Well, we want to support Mozilla's intermediate @@ -159,22 +167,23 @@ class NginxConfigurator(common.Configurator): "certbot_nginx", os.path.join("_internal", "tls_configs", config_filename)) @property - def mod_ssl_conf(self): + def mod_ssl_conf(self) -> str: """Full absolute path to SSL configuration file.""" return os.path.join(self.config.config_dir, constants.MOD_SSL_CONF_DEST) @property - def updated_mod_ssl_conf_digest(self): + def updated_mod_ssl_conf_digest(self) -> str: """Full absolute path to digest of updated SSL configuration file.""" return os.path.join(self.config.config_dir, constants.UPDATED_MOD_SSL_CONF_DIGEST) - def install_ssl_options_conf(self, options_ssl, options_ssl_digest): + def install_ssl_options_conf(self, options_ssl: str, options_ssl_digest: str) -> None: """Copy Certbot's SSL options file into the system's config dir if required.""" - return common.install_version_controlled_file(options_ssl, options_ssl_digest, + common.install_version_controlled_file( + options_ssl, options_ssl_digest, self.mod_ssl_conf_src, constants.ALL_SSL_OPTIONS_HASHES) # This is called in determine_authenticator and determine_installer - def prepare(self): + def prepare(self) -> None: """Prepare the authenticator/installer. :raises .errors.NoInstallationError: If Nginx ctl cannot be found @@ -210,8 +219,8 @@ class NginxConfigurator(common.Configurator): raise errors.PluginError('Unable to lock {0}'.format(self.conf('server-root'))) # Entry point in main.py for installing cert - def deploy_cert(self, domain, cert_path, key_path, - chain_path=None, fullchain_path=None): + def deploy_cert(self, domain: str, cert_path: str, key_path: str, chain_path: str, + fullchain_path: str) -> None: """Deploys certificate to specified virtual host. .. note:: Aborts if the vhost is missing ssl_certificate or @@ -234,7 +243,8 @@ class NginxConfigurator(common.Configurator): display_util.notify("Successfully deployed certificate for {} to {}" .format(domain, vhost.filep)) - def _deploy_cert(self, vhost, cert_path, key_path, chain_path, fullchain_path): # pylint: disable=unused-argument + def _deploy_cert(self, vhost: obj.VirtualHost, _cert_path: str, key_path: str, + _chain_path: str, fullchain_path: str) -> None: """ Helper function for deploy_cert() that handles the actual deployment this exists because we might want to do multiple deployments per @@ -244,8 +254,7 @@ class NginxConfigurator(common.Configurator): cert_directives = [['\n ', 'ssl_certificate', ' ', fullchain_path], ['\n ', 'ssl_certificate_key', ' ', key_path]] - self.parser.update_or_add_server_directives(vhost, - cert_directives) + self.parser.update_or_add_server_directives(vhost, cert_directives) logger.info("Deploying Certificate to VirtualHost %s", vhost.filep) self.save_notes += ("Changed vhost at %s with addresses of %s\n" % @@ -254,7 +263,8 @@ class NginxConfigurator(common.Configurator): self.save_notes += "\tssl_certificate %s\n" % fullchain_path self.save_notes += "\tssl_certificate_key %s\n" % key_path - def _choose_vhosts_wildcard(self, domain, prefer_ssl, no_ssl_filter_port=None): + def _choose_vhosts_wildcard(self, domain: str, prefer_ssl: bool, + no_ssl_filter_port: Optional[str] = None) -> List[obj.VirtualHost]: """Prompts user to choose vhosts to install a wildcard certificate for""" if prefer_ssl: vhosts_cache = self._wildcard_vhosts @@ -303,12 +313,13 @@ class NginxConfigurator(common.Configurator): ####################### # Vhost parsing methods ####################### - def _choose_vhost_single(self, target_name): + def _choose_vhost_single(self, target_name: str) -> List[obj.VirtualHost]: matches = self._get_ranked_matches(target_name) vhosts = [x for x in [self._select_best_name_match(matches)] if x is not None] return vhosts - def choose_vhosts(self, target_name, create_if_no_match=False): + def choose_vhosts(self, target_name: str, + create_if_no_match: bool = False) -> List[obj.VirtualHost]: """Chooses a virtual host based on the given domain name. .. note:: This makes the vhost SSL-enabled if it isn't already. Follows @@ -352,7 +363,7 @@ class NginxConfigurator(common.Configurator): return vhosts - def ipv6_info(self, port): + def ipv6_info(self, port: str) -> Tuple[bool, bool]: """Returns tuple of booleans (ipv6_active, ipv6only_present) ipv6_active is true if any server block listens ipv6 address in any port @@ -365,9 +376,6 @@ class NginxConfigurator(common.Configurator): configuration, and existence of ipv6only directive for specified port :rtype: tuple of type (bool, bool) """ - # port should be a string, but it's easy to mess up, so let's - # make sure it is one - port = str(port) vhosts = self.parser.get_vhosts() ipv6_active = False ipv6only_present = False @@ -377,10 +385,10 @@ class NginxConfigurator(common.Configurator): ipv6_active = True if addr.ipv6only and addr.get_port() == port: ipv6only_present = True - return (ipv6_active, ipv6only_present) + return ipv6_active, ipv6only_present - def _vhost_from_duplicated_default(self, domain: str, allow_port_mismatch: bool, port: str - ) -> obj.VirtualHost: + def _vhost_from_duplicated_default(self, domain: str, allow_port_mismatch: bool, + port: str) -> obj.VirtualHost: """if allow_port_mismatch is False, only server blocks with matching ports will be used as a default server block template. """ @@ -395,7 +403,7 @@ class NginxConfigurator(common.Configurator): self._add_server_name_to_vhost(self.new_vhost, domain) return self.new_vhost - def _add_server_name_to_vhost(self, vhost, domain): + def _add_server_name_to_vhost(self, vhost: obj.VirtualHost, domain: str) -> None: vhost.names.add(domain) name_block = [['\n ', 'server_name']] for name in vhost.names: @@ -403,7 +411,8 @@ class NginxConfigurator(common.Configurator): name_block[0].append(name) self.parser.update_or_add_server_directives(vhost, name_block) - def _get_default_vhost(self, domain, allow_port_mismatch, port): + def _get_default_vhost(self, domain: str, allow_port_mismatch: bool, + port: str) -> obj.VirtualHost: """Helper method for _vhost_from_duplicated_default; see argument documentation there""" vhost_list = self.parser.get_vhosts() # if one has default_server set, return that one @@ -424,10 +433,11 @@ class NginxConfigurator(common.Configurator): # TODO: present a list of vhosts for user to choose from - raise errors.MisconfigurationError("Could not automatically find a matching server" - " block for %s. Set the `server_name` directive to use the Nginx installer." % domain) + raise errors.MisconfigurationError("Could not automatically find a matching server " + f"block for {domain}. Set the `server_name` directive " + "to use the Nginx installer.") - def _get_ranked_matches(self, target_name): + def _get_ranked_matches(self, target_name: str) -> List[Dict[str, Any]]: """Returns a ranked list of vhosts that match target_name. The ranking gives preference to SSL vhosts. @@ -440,7 +450,8 @@ class NginxConfigurator(common.Configurator): vhost_list = self.parser.get_vhosts() return self._rank_matches_by_name_and_ssl(vhost_list, target_name) - def _select_best_name_match(self, matches): + def _select_best_name_match(self, + matches: Sequence[Mapping[str, Any]]) -> Optional[obj.VirtualHost]: """Returns the best name match of a ranked list of vhosts. :param list matches: list of dicts containing the vhost, the matching name, @@ -460,7 +471,8 @@ class NginxConfigurator(common.Configurator): # Exact or regex match return matches[0]['vhost'] - def _rank_matches_by_name(self, vhost_list, target_name): + def _rank_matches_by_name(self, vhost_list: Iterable[obj.VirtualHost], + target_name: str) -> List[Dict[str, Any]]: """Returns a ranked list of vhosts from vhost_list that match target_name. This method should always be followed by a call to _select_best_name_match. @@ -497,7 +509,8 @@ class NginxConfigurator(common.Configurator): 'rank': REGEX_RANK}) return sorted(matches, key=lambda x: x['rank']) - def _rank_matches_by_name_and_ssl(self, vhost_list, target_name): + def _rank_matches_by_name_and_ssl(self, vhost_list: Iterable[obj.VirtualHost], + target_name: str) -> List[Dict[str, Any]]: """Returns a ranked list of vhosts from vhost_list that match target_name. The ranking gives preference to SSLishness before name match level. @@ -610,7 +623,7 @@ class NginxConfigurator(common.Configurator): def _vhost_listening_on_port_no_ssl(self, vhost: obj.VirtualHost, port: str) -> bool: return self._vhost_listening(vhost, port, False) - def _get_redirect_ranked_matches(self, target_name, port): + def _get_redirect_ranked_matches(self, target_name: str, port: str) -> List[Dict[str, Any]]: """Gets a ranked list of plaintextish port-listening vhosts matching target_name Filter all hosts for those listening on port without using ssl. @@ -625,14 +638,14 @@ class NginxConfigurator(common.Configurator): """ all_vhosts = self.parser.get_vhosts() - def _vhost_matches(vhost, port): + def _vhost_matches(vhost: obj.VirtualHost, port: str) -> bool: return self._vhost_listening_on_port_no_ssl(vhost, port) matching_vhosts = [vhost for vhost in all_vhosts if _vhost_matches(vhost, port)] return self._rank_matches_by_name(matching_vhosts, target_name) - def get_all_names(self): + def get_all_names(self) -> Set[str]: """Returns all names found in the Nginx Configuration. :returns: All ServerNames, ServerAliases, and reverse DNS entries for @@ -670,7 +683,7 @@ class NginxConfigurator(common.Configurator): return util.get_filtered_names(all_names) - def _get_snakeoil_paths(self): + def _get_snakeoil_paths(self) -> Tuple[str, str]: """Generate invalid certs that let us create ssl directives for Nginx""" # TODO: generate only once tmp_dir = os.path.join(self.config.work_dir, "snakeoil") @@ -688,7 +701,7 @@ class NginxConfigurator(common.Configurator): cert_file.write(cert_pem) return cert_path, le_key.file - def _make_server_ssl(self, vhost): + def _make_server_ssl(self, vhost: obj.VirtualHost) -> None: """Make a server SSL. Make a server SSL by adding new listen and SSL directives. @@ -698,7 +711,7 @@ class NginxConfigurator(common.Configurator): """ https_port = self.config.https_port - ipv6info = self.ipv6_info(https_port) + ipv6info = self.ipv6_info(str(https_port)) ipv6_block = [''] ipv4_block = [''] @@ -745,11 +758,12 @@ class NginxConfigurator(common.Configurator): ################################## # enhancement methods (Installer) ################################## - def supported_enhancements(self): + def supported_enhancements(self) -> List[str]: """Returns currently supported enhancements.""" return ['redirect', 'ensure-http-header', 'staple-ocsp'] - def enhance(self, domain, enhancement, options=None): + def enhance(self, domain: str, enhancement: str, + options: Optional[Union[str, List[str]]] = None) -> None: """Enhance configuration. :param str domain: domain to enhance @@ -761,16 +775,16 @@ class NginxConfigurator(common.Configurator): """ try: - return self._enhance_func[enhancement](domain, options) + self._enhance_func[enhancement](domain, options) except (KeyError, ValueError): raise errors.PluginError( "Unsupported enhancement: {0}".format(enhancement)) - def _has_certbot_redirect(self, vhost, domain): + def _has_certbot_redirect(self, vhost: obj.VirtualHost, domain: str) -> bool: test_redirect_block = _test_block_from_block(_redirect_block_for_domain(domain)) return vhost.contains_list(test_redirect_block) - def _set_http_header(self, domain, header_substring): + def _set_http_header(self, domain: str, header_substring: Union[str, List[str], None]) -> None: """Enables header identified by header_substring on domain. If the vhost is listening plaintextishly, separates out the relevant @@ -784,7 +798,10 @@ class NginxConfigurator(common.Configurator): :raises .errors.PluginError: If no viable HTTPS host can be created or set with header header_substring. """ - if not header_substring in constants.HEADER_ARGS: + if not isinstance(header_substring, str): + raise errors.NotSupportedError("Invalid header_substring type " + f"{type(header_substring)}, expected a str.") + if header_substring not in constants.HEADER_ARGS: raise errors.NotSupportedError( f"{header_substring} is not supported by the nginx plugin.") @@ -808,7 +825,7 @@ class NginxConfigurator(common.Configurator): ['\n']] self.parser.add_server_directives(vhost, header_directives) - def _add_redirect_block(self, vhost, domain): + def _add_redirect_block(self, vhost: obj.VirtualHost, domain: str) -> None: """Add redirect directive to vhost """ redirect_block = _redirect_block_for_domain(domain) @@ -816,7 +833,8 @@ class NginxConfigurator(common.Configurator): self.parser.add_server_directives( vhost, redirect_block, insert_at_top=True) - def _split_block(self, vhost, only_directives=None): + def _split_block(self, vhost: obj.VirtualHost, only_directives: Optional[List[str]] = None + ) -> Tuple[obj.VirtualHost, obj.VirtualHost]: """Splits this "virtual host" (i.e. this nginx server block) into separate HTTP and HTTPS blocks. @@ -829,13 +847,13 @@ class NginxConfigurator(common.Configurator): """ http_vhost = self.parser.duplicate_vhost(vhost, only_directives=only_directives) - def _ssl_match_func(directive): + def _ssl_match_func(directive: str) -> bool: return 'ssl' in directive - def _ssl_config_match_func(directive): + def _ssl_config_match_func(directive: str) -> bool: return self.mod_ssl_conf in directive - def _no_ssl_match_func(directive): + def _no_ssl_match_func(directive: str) -> bool: return 'ssl' not in directive # remove all ssl addresses and related directives from the new block @@ -849,7 +867,8 @@ class NginxConfigurator(common.Configurator): self.parser.remove_server_directives(vhost, 'listen', match_func=_no_ssl_match_func) return http_vhost, vhost - def _enable_redirect(self, domain, unused_options): + def _enable_redirect(self, domain: str, + unused_options: Optional[Union[str, List[str]]]) -> None: """Redirect all equivalent HTTP traffic to ssl_vhost. If the vhost is listening plaintextishly, separate out the @@ -876,7 +895,7 @@ class NginxConfigurator(common.Configurator): for vhost in vhosts: self._enable_redirect_single(domain, vhost) - def _enable_redirect_single(self, domain, vhost): + def _enable_redirect_single(self, domain: str, vhost: obj.VirtualHost) -> None: """Redirect all equivalent HTTP traffic to ssl_vhost. If the vhost is listening plaintextishly, separate out the @@ -905,7 +924,8 @@ class NginxConfigurator(common.Configurator): logger.info("Redirecting all traffic on port %s to ssl in %s", self.DEFAULT_LISTEN_PORT, vhost.filep) - def _enable_ocsp_stapling(self, domain, chain_path): + def _enable_ocsp_stapling(self, domain: str, + chain_path: Optional[Union[str, List[str]]]) -> None: """Include OCSP response in TLS handshake :param str domain: domain to enable OCSP response for @@ -913,11 +933,15 @@ class NginxConfigurator(common.Configurator): :type chain_path: `str` or `None` """ + if not isinstance(chain_path, str) and chain_path is not None: + raise errors.NotSupportedError(f"Invalid chain_path type {type(chain_path)}, " + "expected a str or None.") vhosts = self.choose_vhosts(domain) for vhost in vhosts: self._enable_ocsp_stapling_single(vhost, chain_path) - def _enable_ocsp_stapling_single(self, vhost, chain_path): + def _enable_ocsp_stapling_single(self, vhost: obj.VirtualHost, + chain_path: Optional[str]) -> None: """Include OCSP response in TLS handshake :param str vhost: vhost to enable OCSP response for @@ -957,7 +981,7 @@ class NginxConfigurator(common.Configurator): ###################################### # Nginx server management (Installer) ###################################### - def restart(self): + def restart(self) -> None: """Restarts nginx server. :raises .errors.MisconfigurationError: If either the reload fails. @@ -965,7 +989,7 @@ class NginxConfigurator(common.Configurator): """ nginx_restart(self.conf('ctl'), self.nginx_conf, self.conf('sleep-seconds')) - def config_test(self): + def config_test(self) -> None: """Check the configuration of Nginx for errors. :raises .errors.MisconfigurationError: If config_test fails @@ -976,7 +1000,7 @@ class NginxConfigurator(common.Configurator): except errors.SubprocessError as err: raise errors.MisconfigurationError(str(err)) - def _nginx_version(self): + def _nginx_version(self) -> str: """Return results of nginx -V :returns: version text @@ -1000,7 +1024,7 @@ class NginxConfigurator(common.Configurator): "Unable to run %s -V" % self.conf('ctl')) return text - def get_version(self): + def get_version(self) -> Tuple[int, ...]: """Return version of Nginx Server. Version is returned as tuple. (ie. 2.4.7 = (2, 4, 7)) @@ -1045,7 +1069,7 @@ class NginxConfigurator(common.Configurator): return nginx_version - def _get_openssl_version(self): + def _get_openssl_version(self) -> str: """Return version of OpenSSL linked to Nginx. Version is returned as string. If no version can be found, empty string is returned. @@ -1067,7 +1091,7 @@ class NginxConfigurator(common.Configurator): return "" return matches[0] - def more_info(self): + def more_info(self) -> str: """Human-readable string to help understand the module""" return ( "Configures Nginx to authenticate and install HTTPS.{0}" @@ -1077,7 +1101,8 @@ class NginxConfigurator(common.Configurator): version=".".join(str(i) for i in self.version)) ) - def auth_hint(self, failed_achalls): # pragma: no cover + def auth_hint(self, # pragma: no cover + failed_achalls: Iterable[achallenges.AnnotatedChallenge]) -> str: return ( "The Certificate Authority failed to verify the temporary nginx configuration changes " "made by Certbot. Ensure the listed domains point to this nginx server and that it is " @@ -1087,7 +1112,7 @@ class NginxConfigurator(common.Configurator): ################################################### # Wrapper functions for Reverter class (Installer) ################################################### - def save(self, title=None, temporary=False): + def save(self, title: str = None, temporary: bool = False) -> None: """Saves all changes to the configuration files. :param str title: The title of the save. If a title is given, the @@ -1111,7 +1136,7 @@ class NginxConfigurator(common.Configurator): if title and not temporary: self.finalize_checkpoint(title) - def recovery_routine(self): + def recovery_routine(self) -> None: """Revert all previously modified files. Reverts all modified files that have not been saved as a checkpoint @@ -1123,7 +1148,7 @@ class NginxConfigurator(common.Configurator): self.new_vhost = None self.parser.load() - def revert_challenge_config(self): + def revert_challenge_config(self) -> None: """Used to cleanup challenge configurations. :raises .errors.PluginError: If unable to revert the challenge config. @@ -1133,7 +1158,7 @@ class NginxConfigurator(common.Configurator): self.new_vhost = None self.parser.load() - def rollback_checkpoints(self, rollback=1): + def rollback_checkpoints(self, rollback: int = 1) -> None: """Rollback saved checkpoints. :param int rollback: Number of checkpoints to revert @@ -1149,12 +1174,13 @@ class NginxConfigurator(common.Configurator): ########################################################################### # Challenges Section for Authenticator ########################################################################### - def get_chall_pref(self, unused_domain): + def get_chall_pref(self, unused_domain: str) -> List[Type[challenges.Challenge]]: """Return list of challenge preferences.""" return [challenges.HTTP01] # Entry point in main.py for performing challenges - def perform(self, achalls): + def perform(self, achalls: List[achallenges.AnnotatedChallenge] + ) -> List[challenges.HTTP01Response]: """Perform the configuration related challenge. This function currently assumes all challenges will be fulfilled. @@ -1163,7 +1189,7 @@ class NginxConfigurator(common.Configurator): """ self._chall_out += len(achalls) - responses = [None] * len(achalls) + responses: List[Optional[challenges.HTTP01Response]] = [None] * len(achalls) http_doer = http_01.NginxHttp01(self) for i, achall in enumerate(achalls): @@ -1183,10 +1209,10 @@ class NginxConfigurator(common.Configurator): for i, resp in enumerate(http_response): responses[http_doer.indices[i]] = resp - return responses + return [response for response in responses if response] # called after challenges are performed - def cleanup(self, achalls): + def cleanup(self, achalls: List[achallenges.AnnotatedChallenge]) -> None: """Revert all challenges.""" self._chall_out -= len(achalls) @@ -1196,13 +1222,13 @@ class NginxConfigurator(common.Configurator): self.restart() -def _test_block_from_block(block): +def _test_block_from_block(block: List[Any]) -> List[Any]: test_block = nginxparser.UnspacedList(block) parser.comment_directive(test_block, 0) return test_block[:-1] -def _redirect_block_for_domain(domain): +def _redirect_block_for_domain(domain: str) -> List[Any]: updated_domain = domain match_symbol = '=' if util.is_wildcard_domain(domain): @@ -1218,7 +1244,7 @@ def _redirect_block_for_domain(domain): return redirect_block -def nginx_restart(nginx_ctl, nginx_conf, sleep_duration): +def nginx_restart(nginx_ctl: str, nginx_conf: str, sleep_duration: int) -> None: """Restarts the Nginx Server. .. todo:: Nginx restart is fatal if the configuration references @@ -1263,10 +1289,10 @@ def nginx_restart(nginx_ctl, nginx_conf, sleep_duration): time.sleep(sleep_duration) -def _determine_default_server_root(): +def _determine_default_server_root() -> str: if os.environ.get("CERTBOT_DOCS") == "1": - default_server_root = "%s or %s" % (constants.LINUX_SERVER_ROOT, - constants.FREEBSD_DARWIN_SERVER_ROOT) + default_server_root = (f"{constants.LINUX_SERVER_ROOT} " + f"or {constants.FREEBSD_DARWIN_SERVER_ROOT}") else: default_server_root = constants.CLI_DEFAULTS["server_root"] return default_server_root diff --git a/certbot-nginx/certbot_nginx/_internal/constants.py b/certbot-nginx/certbot_nginx/_internal/constants.py index a0946b0fd..aa242903e 100644 --- a/certbot-nginx/certbot_nginx/_internal/constants.py +++ b/certbot-nginx/certbot_nginx/_internal/constants.py @@ -52,7 +52,8 @@ ALL_SSL_OPTIONS_HASHES = [ ] """SHA256 hashes of the contents of all versions of MOD_SSL_CONF_SRC""" -def os_constant(key): + +def os_constant(key: str) -> Any: # XXX TODO: In the future, this could return different constants # based on what OS we are running under. To see an # approach to how to handle different OSes, see the @@ -61,11 +62,12 @@ def os_constant(key): """ Get a constant value for operating system - :param key: name of cli constant + :param str key: name of cli constant :return: value of constant for active os """ return CLI_DEFAULTS[key] + HSTS_ARGS = ['\"max-age=31536000\"', ' ', 'always'] HEADER_ARGS = {'Strict-Transport-Security': HSTS_ARGS} diff --git a/certbot-nginx/certbot_nginx/_internal/display_ops.py b/certbot-nginx/certbot_nginx/_internal/display_ops.py index 95163a81f..89483c94a 100644 --- a/certbot-nginx/certbot_nginx/_internal/display_ops.py +++ b/certbot-nginx/certbot_nginx/_internal/display_ops.py @@ -1,12 +1,17 @@ """Contains UI methods for Nginx operations.""" import logging +from typing import Iterable +from typing import List +from typing import Optional + +from certbot_nginx._internal.obj import VirtualHost from certbot.display import util as display_util logger = logging.getLogger(__name__) -def select_vhost_multiple(vhosts): +def select_vhost_multiple(vhosts: Optional[Iterable[VirtualHost]]) -> List[VirtualHost]: """Select multiple Vhosts to install the certificate for :param vhosts: Available Nginx VirtualHosts :type vhosts: :class:`list` of type `~obj.Vhost` @@ -28,7 +33,7 @@ def select_vhost_multiple(vhosts): return [] -def _reversemap_vhosts(names, vhosts): +def _reversemap_vhosts(names: Iterable[str], vhosts: Iterable[VirtualHost]) -> List[VirtualHost]: """Helper function for select_vhost_multiple for mapping string representations back to actual vhost objects""" return_vhosts = [] diff --git a/certbot-nginx/certbot_nginx/_internal/http_01.py b/certbot-nginx/certbot_nginx/_internal/http_01.py index 95592fea0..6f61bfb6f 100644 --- a/certbot-nginx/certbot_nginx/_internal/http_01.py +++ b/certbot-nginx/certbot_nginx/_internal/http_01.py @@ -2,17 +2,20 @@ import io import logging +from typing import Any from typing import List from typing import Optional from typing import TYPE_CHECKING +from certbot_nginx._internal import nginxparser +from certbot_nginx._internal.obj import Addr + from acme import challenges -from certbot import achallenges +from acme.challenges import HTTP01Response from certbot import errors +from certbot.achallenges import KeyAuthorizationAnnotatedChallenge from certbot.compat import os from certbot.plugins import common -from certbot_nginx._internal import nginxparser -from certbot_nginx._internal import obj if TYPE_CHECKING: from certbot_nginx._internal.configurator import NginxConfigurator @@ -46,7 +49,7 @@ class NginxHttp01(common.ChallengePerformer): self.challenge_conf = os.path.join( configurator.config.config_dir, "le_http_01_cert_challenge.conf") - def perform(self): + def perform(self) -> List[HTTP01Response]: """Perform a challenge on Nginx. :returns: list of :class:`certbot.acme.challenges.HTTP01Response` @@ -66,7 +69,7 @@ class NginxHttp01(common.ChallengePerformer): return responses - def _mod_config(self): + def _mod_config(self) -> None: """Modifies Nginx config to include server_names_hash_bucket_size directive and server challenge blocks. @@ -113,39 +116,40 @@ class NginxHttp01(common.ChallengePerformer): with io.open(self.challenge_conf, "w", encoding="utf-8") as new_conf: nginxparser.dump(config, new_conf) - def _default_listen_addresses(self): + def _default_listen_addresses(self) -> List[Addr]: """Finds addresses for a challenge block to listen on. :returns: list of :class:`certbot_nginx._internal.obj.Addr` to apply :rtype: list """ - addresses: List[obj.Addr] = [] + addresses: List[Optional[Addr]] = [] default_addr = "%s" % self.configurator.config.http01_port ipv6_addr = "[::]:{0}".format( self.configurator.config.http01_port) port = self.configurator.config.http01_port - ipv6, ipv6only = self.configurator.ipv6_info(port) + ipv6, ipv6only = self.configurator.ipv6_info(str(port)) if ipv6: # If IPv6 is active in Nginx configuration if not ipv6only: # If ipv6only=on is not already present in the config ipv6_addr = ipv6_addr + " ipv6only=on" - addresses = [obj.Addr.fromstring(default_addr), - obj.Addr.fromstring(ipv6_addr)] + addresses = [Addr.fromstring(default_addr), + Addr.fromstring(ipv6_addr)] logger.debug(("Using default addresses %s and %s for authentication."), default_addr, ipv6_addr) else: - addresses = [obj.Addr.fromstring(default_addr)] + addresses = [Addr.fromstring(default_addr)] logger.debug("Using default address %s for authentication.", default_addr) - return addresses - def _get_validation_path(self, achall): + return [address for address in addresses if address] + + def _get_validation_path(self, achall: KeyAuthorizationAnnotatedChallenge) -> str: return os.sep + os.path.join(challenges.HTTP01.URI_ROOT_PATH, achall.chall.encode("token")) - def _make_server_block(self, achall: achallenges.KeyAuthorizationAnnotatedChallenge) -> List: + def _make_server_block(self, achall: KeyAuthorizationAnnotatedChallenge) -> List[Any]: """Creates a server block for a challenge. :param achall: Annotated HTTP-01 challenge @@ -168,7 +172,8 @@ class NginxHttp01(common.ChallengePerformer): # TODO: do we want to return something else if they otherwise access this block? return [['server'], block] - def _location_directive_for_achall(self, achall): + def _location_directive_for_achall(self, achall: KeyAuthorizationAnnotatedChallenge + ) -> List[Any]: validation = achall.validation(achall.account_key) validation_path = self._get_validation_path(achall) @@ -177,9 +182,8 @@ class NginxHttp01(common.ChallengePerformer): ['return', ' ', '200', ' ', validation]]] return location_directive - - def _make_or_mod_server_block(self, achall: achallenges.KeyAuthorizationAnnotatedChallenge - ) -> Optional[List]: + def _make_or_mod_server_block(self, achall: KeyAuthorizationAnnotatedChallenge + ) -> Optional[List[Any]]: """Modifies server blocks to respond to a challenge. Returns a new HTTP server block to add to the configuration if an existing one can't be found. @@ -192,7 +196,7 @@ class NginxHttp01(common.ChallengePerformer): """ http_vhosts, https_vhosts = self.configurator.choose_auth_vhosts(achall.domain) - new_vhost: Optional[list] = None + new_vhost: Optional[List[Any]] = None if not http_vhosts: # Couldn't find either a matching name+port server block # or a port+default_server block, so create a dummy block @@ -205,8 +209,8 @@ class NginxHttp01(common.ChallengePerformer): self.configurator.parser.add_server_directives(vhost, location_directive) rewrite_directive = [['rewrite', ' ', '^(/.well-known/acme-challenge/.*)', - ' ', '$1', ' ', 'break']] - self.configurator.parser.add_server_directives(vhost, - rewrite_directive, insert_at_top=True) + ' ', '$1', ' ', 'break']] + self.configurator.parser.add_server_directives( + vhost, rewrite_directive, insert_at_top=True) return new_vhost diff --git a/certbot-nginx/certbot_nginx/_internal/nginxparser.py b/certbot-nginx/certbot_nginx/_internal/nginxparser.py index 03aa88db4..70a55be3a 100644 --- a/certbot-nginx/certbot_nginx/_internal/nginxparser.py +++ b/certbot-nginx/certbot_nginx/_internal/nginxparser.py @@ -2,14 +2,23 @@ # Forked from https://github.com/fatiherikli/nginxparser (MIT Licensed) import copy import logging +import typing from typing import Any from typing import IO +from typing import Iterable +from typing import Iterator +from typing import List +from typing import overload +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union from pyparsing import Combine from pyparsing import Forward from pyparsing import Group from pyparsing import Literal from pyparsing import Optional +from pyparsing import ParseResults from pyparsing import QuotedString from pyparsing import Regex from pyparsing import restOfLine @@ -17,6 +26,9 @@ from pyparsing import stringEnd from pyparsing import White from pyparsing import ZeroOrMore +if TYPE_CHECKING: + from typing_extensions import SupportsIndex # typing.SupportsIndex not supported on Python 3.6 + logger = logging.getLogger(__name__) @@ -59,23 +71,24 @@ class RawNginxParser: script = ZeroOrMore(contents) + space + stringEnd script.parseWithTabs().leaveWhitespace() - def __init__(self, source): + def __init__(self, source: str) -> None: self.source = source - def parse(self): + def parse(self) -> ParseResults: """Returns the parsed tree.""" return self.script.parseString(self.source) - def as_list(self): + def as_list(self) -> List[Any]: """Returns the parsed tree as a list.""" return self.parse().asList() + class RawNginxDumper: """A class that dumps nginx configuration from the provided tree.""" - def __init__(self, blocks): + def __init__(self, blocks: List[Any]) -> None: self.blocks = blocks - def __iter__(self, blocks=None): + def __iter__(self, blocks: typing.Optional[List[Any]] = None) -> Iterator[str]: """Iterates the dumped nginx content.""" blocks = blocks or self.blocks for b0 in blocks: @@ -100,7 +113,7 @@ class RawNginxDumper: semicolon = "" yield "".join(item) + semicolon - def __str__(self): + def __str__(self) -> str: """Return the parsed block as a string.""" return ''.join(self) @@ -108,10 +121,10 @@ class RawNginxDumper: spacey = lambda x: (isinstance(x, str) and x.isspace()) or x == '' -class UnspacedList(list): +class UnspacedList(List[Any]): """Wrap a list [of lists], making any whitespace entries magically invisible""" - def __init__(self, list_source): + def __init__(self, list_source: Iterable[Any]) -> None: # ensure our argument is not a generator, and duplicate any sublists self.spaced = copy.deepcopy(list(list_source)) self.dirty = False @@ -122,14 +135,23 @@ class UnspacedList(list): for i, entry in reversed(list(enumerate(self))): if isinstance(entry, list): sublist = UnspacedList(entry) - list.__setitem__(self, i, sublist) + super().__setitem__(i, sublist) self.spaced[i] = sublist.spaced elif spacey(entry): # don't delete comments if "#" not in self[:i]: - list.__delitem__(self, i) + super().__delitem__(i) + + @overload + def _coerce(self, inbound: None) -> Tuple[None, None]: ... - def _coerce(self, inbound): + @overload + def _coerce(self, inbound: str) -> Tuple[str, str]: ... + + @overload + def _coerce(self, inbound: List[Any]) -> Tuple["UnspacedList", List[Any]]: ... + + def _coerce(self, inbound: Any) -> Tuple[Any, Any]: """ Coerce some inbound object to be appropriately usable in this object @@ -138,100 +160,114 @@ class UnspacedList(list): :rtype: tuple """ - if not isinstance(inbound, list): # str or None + if not isinstance(inbound, list): # str or None return inbound, inbound else: if not hasattr(inbound, "spaced"): inbound = UnspacedList(inbound) return inbound, inbound.spaced - def insert(self, i, x): + def insert(self, i: int, x: Any) -> None: + """Insert object before index.""" item, spaced_item = self._coerce(x) slicepos = self._spaced_position(i) if i < len(self) else len(self.spaced) self.spaced.insert(slicepos, spaced_item) if not spacey(item): - list.insert(self, i, item) + super().insert(i, item) self.dirty = True - def append(self, x): + def append(self, x: Any) -> None: + """Append object to the end of the list.""" item, spaced_item = self._coerce(x) self.spaced.append(spaced_item) if not spacey(item): - list.append(self, item) + super().append(item) self.dirty = True - def extend(self, x): + def extend(self, x: Any) -> None: + """Extend list by appending elements from the iterable.""" item, spaced_item = self._coerce(x) self.spaced.extend(spaced_item) - list.extend(self, item) + super().extend(item) self.dirty = True - def __add__(self, other): - l = copy.deepcopy(self) - l.extend(other) - l.dirty = True - return l + def __add__(self, other: List[Any]) -> "UnspacedList": + new_list = copy.deepcopy(self) + new_list.extend(other) + new_list.dirty = True + return new_list - def pop(self, _i=None): + def pop(self, *args: Any, **kwargs: Any) -> None: + """Function pop() is not implemented for UnspacedList""" raise NotImplementedError("UnspacedList.pop() not yet implemented") - def remove(self, _): + + def remove(self, *args: Any, **kwargs: Any) -> None: + """Function remove() is not implemented for UnspacedList""" raise NotImplementedError("UnspacedList.remove() not yet implemented") - def reverse(self): + + def reverse(self) -> None: + """Function reverse() is not implemented for UnspacedList""" raise NotImplementedError("UnspacedList.reverse() not yet implemented") - def sort(self, _cmp=None, _key=None, _Rev=None): + + def sort(self, *_args: Any, **_kwargs: Any) -> None: + """Function sort() is not implemented for UnspacedList""" raise NotImplementedError("UnspacedList.sort() not yet implemented") - def __setslice__(self, _i, _j, _newslice): + + def __setslice__(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError("Slice operations on UnspacedLists not yet implemented") - def __setitem__(self, i, value): + def __setitem__(self, i: Union["SupportsIndex", slice], value: Any) -> None: if isinstance(i, slice): raise NotImplementedError("Slice operations on UnspacedLists not yet implemented") item, spaced_item = self._coerce(value) self.spaced.__setitem__(self._spaced_position(i), spaced_item) if not spacey(item): - list.__setitem__(self, i, item) + super().__setitem__(i, item) self.dirty = True - def __delitem__(self, i): + def __delitem__(self, i: Union["SupportsIndex", slice]) -> None: + if isinstance(i, slice): + raise NotImplementedError("Slice operations on UnspacedLists not yet implemented") self.spaced.__delitem__(self._spaced_position(i)) - list.__delitem__(self, i) + super().__delitem__(i) self.dirty = True - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Any) -> "UnspacedList": new_spaced = copy.deepcopy(self.spaced, memo=memo) - l = UnspacedList(new_spaced) - l.dirty = self.dirty - return l + new_list = UnspacedList(new_spaced) + new_list.dirty = self.dirty + return new_list - def is_dirty(self): + def is_dirty(self) -> bool: """Recurse through the parse tree to figure out if any sublists are dirty""" if self.dirty: return True return any((isinstance(x, UnspacedList) and x.is_dirty() for x in self)) - def _spaced_position(self, idx): - "Convert from indexes in the unspaced list to positions in the spaced one" + def _spaced_position(self, idx: "SupportsIndex") -> int: + """Convert from indexes in the unspaced list to positions in the spaced one""" + int_idx = idx.__index__() pos = spaces = 0 # Normalize indexes like list[-1] etc, and save the result - if idx < 0: - idx = len(self) + idx - if not 0 <= idx < len(self): + if int_idx < 0: + int_idx = len(self) + int_idx + if not 0 <= int_idx < len(self): raise IndexError("list index out of range") - idx0 = idx - # Count the number of spaces in the spaced list before idx in the unspaced one - while idx != -1: + int_idx0 = int_idx + # Count the number of spaces in the spaced list before int_idx in the unspaced one + while int_idx != -1: if spacey(self.spaced[pos]): spaces += 1 else: - idx -= 1 + int_idx -= 1 pos += 1 - return idx0 + spaces + return int_idx0 + spaces # Shortcut functions to respect Python's serialization interface # (like pyyaml, picker or json) -def loads(source): +def loads(source: str) -> UnspacedList: """Parses from a string. :param str source: The string to parse @@ -242,34 +278,34 @@ def loads(source): return UnspacedList(RawNginxParser(source).as_list()) -def load(_file): +def load(file_: IO[Any]) -> UnspacedList: """Parses from a file. - :param file _file: The file to parse + :param file file_: The file to parse :returns: The parsed tree :rtype: list """ - return loads(_file.read()) + return loads(file_.read()) def dumps(blocks: UnspacedList) -> str: """Dump to a Unicode string. - :param UnspacedList block: The parsed tree + :param UnspacedList blocks: The parsed tree :rtype: six.text_type """ return str(RawNginxDumper(blocks.spaced)) -def dump(blocks: UnspacedList, _file: IO[Any]) -> None: +def dump(blocks: UnspacedList, file_: IO[Any]) -> None: """Dump to a file. - :param UnspacedList block: The parsed tree - :param IO[Any] _file: The file stream to dump to. It must be opened with + :param UnspacedList blocks: The parsed tree + :param IO[Any] file_: The file stream to dump to. It must be opened with Unicode encoding. :rtype: None """ - _file.write(dumps(blocks)) + file_.write(dumps(blocks)) diff --git a/certbot-nginx/certbot_nginx/_internal/obj.py b/certbot-nginx/certbot_nginx/_internal/obj.py index 44be0e598..1e0dbba1a 100644 --- a/certbot-nginx/certbot_nginx/_internal/obj.py +++ b/certbot-nginx/certbot_nginx/_internal/obj.py @@ -1,10 +1,17 @@ """Module contains classes used by the Nginx Configurator.""" import re +from typing import Any +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Union from certbot.plugins import common ADD_HEADER_DIRECTIVE = 'add_header' + class Addr(common.Addr): r"""Represents an Nginx address, i.e. what comes after the 'listen' directive. @@ -34,7 +41,8 @@ class Addr(common.Addr): UNSPECIFIED_IPV4_ADDRESSES = ('', '*', '0.0.0.0') CANONICAL_UNSPECIFIED_ADDRESS = UNSPECIFIED_IPV4_ADDRESSES[0] - def __init__(self, host, port, ssl, default, ipv6, ipv6only): + def __init__(self, host: str, port: str, ssl: bool, default: bool, + ipv6: bool, ipv6only: bool) -> None: super().__init__((host, port)) self.ssl = ssl self.default = default @@ -43,7 +51,7 @@ class Addr(common.Addr): self.unspecified_address = host in self.UNSPECIFIED_IPV4_ADDRESSES @classmethod - def fromstring(cls, str_addr): + def fromstring(cls, str_addr: str) -> Optional["Addr"]: """Initialize Addr from string.""" parts = str_addr.split(' ') ssl = False @@ -94,7 +102,7 @@ class Addr(common.Addr): return cls(host, port, ssl, default, ipv6, ipv6only) - def to_string(self, include_default=True): + def to_string(self, include_default: bool = True) -> str: """Return string representation of Addr""" parts = '' if self.tup[0] and self.tup[1]: @@ -111,18 +119,18 @@ class Addr(common.Addr): return parts - def __str__(self): + def __str__(self) -> str: return self.to_string() - def __repr__(self): + def __repr__(self) -> str: return "Addr(" + self.__str__() + ")" - def __hash__(self): # pylint: disable=useless-super-delegation + def __hash__(self) -> int: # pylint: disable=useless-super-delegation # Python 3 requires explicit overridden for __hash__ # See certbot-apache/certbot_apache/_internal/obj.py for more information return super().__hash__() - def super_eq(self, other): + def super_eq(self, other: "Addr") -> bool: """Check ip/port equality, with IPv6 support. """ # If both addresses got an unspecified address, then make sure the @@ -134,7 +142,7 @@ class Addr(common.Addr): other.tup[1]), other.ipv6) return super().__eq__(other) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): return (self.super_eq(other) and self.ssl == other.ssl and @@ -158,7 +166,8 @@ class VirtualHost: """ - def __init__(self, filep, addrs, ssl, enabled, names, raw, path): + def __init__(self, filep: str, addrs: Sequence[Addr], ssl: bool, enabled: bool, + names: Set[str], raw: List[Any], path: List[int]) -> None: """Initialize a VH.""" self.filep = filep self.addrs = addrs @@ -168,7 +177,7 @@ class VirtualHost: self.raw = raw self.path = path - def __str__(self): + def __str__(self) -> str: addr_str = ", ".join(str(addr) for addr in sorted(self.addrs, key=str)) # names might be a set, and it has different representations in Python # 2 and 3. Force it to be a list here for consistent outputs @@ -179,10 +188,10 @@ class VirtualHost: "enabled: %s" % (self.filep, addr_str, list(self.names), self.ssl, self.enabled)) - def __repr__(self): + def __repr__(self) -> str: return "VirtualHost(" + self.__str__().replace("\n", ", ") + ")\n" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): return (self.filep == other.filep and sorted(self.addrs, key=str) == sorted(other.addrs, key=str) and @@ -193,12 +202,12 @@ class VirtualHost: return False - def __hash__(self): + def __hash__(self) -> int: return hash((self.filep, tuple(self.path), tuple(self.addrs), tuple(self.names), self.ssl, self.enabled)) - def has_header(self, header_name): + def has_header(self, header_name: str) -> bool: """Determine if this server block has a particular header set. :param str header_name: The name of the header to check for, e.g. 'Strict-Transport-Security' @@ -206,7 +215,7 @@ class VirtualHost: found = _find_directive(self.raw, ADD_HEADER_DIRECTIVE, header_name) return found is not None - def contains_list(self, test): + def contains_list(self, test: List[Any]) -> bool: """Determine if raw server block contains test list at top level """ for i in range(0, len(self.raw) - len(test) + 1): @@ -214,7 +223,7 @@ class VirtualHost: return True return False - def ipv6_enabled(self): + def ipv6_enabled(self) -> bool: """Return true if one or more of the listen directives in vhost supports IPv6""" for a in self.addrs: @@ -222,7 +231,7 @@ class VirtualHost: return True return False - def ipv4_enabled(self): + def ipv4_enabled(self) -> bool: """Return true if one or more of the listen directives in vhost are IPv4 only""" if not self.addrs: @@ -232,7 +241,7 @@ class VirtualHost: return True return False - def display_repr(self): + def display_repr(self) -> str: """Return a representation of VHost to be used in dialog""" return ( "File: {filename}\n" @@ -244,7 +253,9 @@ class VirtualHost: names=", ".join(self.names), https="Yes" if self.ssl else "No")) -def _find_directive(directives, directive_name, match_content=None): + +def _find_directive(directives: Optional[Union[str, List[Any]]], directive_name: str, + match_content: Optional[Any] = None) -> Optional[Any]: """Find a directive of type directive_name in directives. If match_content is given, Searches for `match_content` in the directive arguments. """ diff --git a/certbot-nginx/certbot_nginx/_internal/parser.py b/certbot-nginx/certbot_nginx/_internal/parser.py index 83ec4c923..e60d66eb2 100644 --- a/certbot-nginx/certbot_nginx/_internal/parser.py +++ b/certbot-nginx/certbot_nginx/_internal/parser.py @@ -5,20 +5,26 @@ import glob import io import logging import re +from typing import Any +from typing import Callable +from typing import cast from typing import Dict +from typing import Iterable from typing import List +from typing import Mapping from typing import Optional +from typing import Sequence from typing import Set from typing import Tuple from typing import Union +from certbot_nginx._internal import nginxparser +from certbot_nginx._internal import obj +from certbot_nginx._internal.nginxparser import UnspacedList import pyparsing from certbot import errors from certbot.compat import os -from certbot_nginx._internal import nginxparser -from certbot_nginx._internal import obj -from certbot_nginx._internal.nginxparser import UnspacedList logger = logging.getLogger(__name__) @@ -32,8 +38,8 @@ class NginxParser: """ - def __init__(self, root): - self.parsed: Dict[str, Union[List, nginxparser.UnspacedList]] = {} + def __init__(self, root: str) -> None: + self.parsed: Dict[str, UnspacedList] = {} self.root = os.path.abspath(root) self.config_root = self._find_config_root() @@ -42,14 +48,14 @@ class NginxParser: # not enable sites from there. self.load() - def load(self): + def load(self) -> None: """Loads Nginx files into a parsed tree. """ self.parsed = {} self._parse_recursively(self.config_root) - def _parse_recursively(self, filepath): + def _parse_recursively(self, filepath: str) -> None: """Parses nginx config files recursively by looking at 'include' directives inside 'http' and 'server' blocks. Note that this only reads Nginx files that potentially declare a virtual host. @@ -77,7 +83,7 @@ class NginxParser: if _is_include_directive(server_entry): self._parse_recursively(server_entry[1]) - def abs_path(self, path): + def abs_path(self, path: str) -> str: """Converts a relative path to an absolute path relative to the root. Does nothing for paths that are already absolute. @@ -90,7 +96,7 @@ class NginxParser: return os.path.normpath(os.path.join(self.root, path)) return os.path.normpath(path) - def _build_addr_to_ssl(self): + def _build_addr_to_ssl(self) -> Dict[Tuple[str, str], bool]: """Builds a map from address to whether it listens on ssl in any server block """ servers = self._get_raw_servers() @@ -107,11 +113,11 @@ class NginxParser: addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple] return addr_to_ssl - def _get_raw_servers(self) -> Dict: + def _get_raw_servers(self) -> Dict[str, Union[List[Any], UnspacedList]]: # pylint: disable=cell-var-from-loop """Get a map of unparsed all server blocks """ - servers: Dict[str, Union[List, nginxparser.UnspacedList]] = {} + servers: Dict[str, Union[List[Any], nginxparser.UnspacedList]] = {} for filename, tree in self.parsed.items(): servers[filename] = [] srv = servers[filename] # workaround undefined loop var in lambdas @@ -126,7 +132,7 @@ class NginxParser: servers[filename][i] = (new_server, path) return servers - def get_vhosts(self): + def get_vhosts(self) -> List[obj.VirtualHost]: """Gets list of all 'virtual hosts' found in Nginx configuration. Technically this is a misnomer because Nginx does not have virtual hosts, it has 'server blocks'. @@ -158,7 +164,7 @@ class NginxParser: return vhosts - def _update_vhosts_addrs_ssl(self, vhosts): + def _update_vhosts_addrs_ssl(self, vhosts: Iterable[obj.VirtualHost]) -> None: """Update a list of raw parsed vhosts to include global address sslishness """ addr_to_ssl = self._build_addr_to_ssl() @@ -168,7 +174,7 @@ class NginxParser: if addr.ssl: vhost.ssl = True - def _get_included_directives(self, block): + def _get_included_directives(self, block: UnspacedList) -> UnspacedList: """Returns array with the "include" directives expanded out by concatenating the contents of the included file to the block. @@ -188,7 +194,7 @@ class NginxParser: pass return result - def _parse_files(self, filepath, override=False): + def _parse_files(self, filepath: str, override: bool = False) -> List[UnspacedList]: """Parse files from a glob :param str filepath: Nginx config file path @@ -219,7 +225,7 @@ class NginxParser: logger.warning("Could not parse file: %s due to %s", item, err) return trees - def _find_config_root(self): + def _find_config_root(self) -> str: """Return the Nginx Configuration Root file.""" location = ['nginx.conf'] @@ -230,7 +236,7 @@ class NginxParser: raise errors.NoInstallationError( "Could not find Nginx root configuration file (nginx.conf)") - def filedump(self, ext='tmp', lazy=True): + def filedump(self, ext: str = 'tmp', lazy: bool = True) -> None: """Dumps parsed configurations into files. :param str ext: The file extension to use for the dumped files. If @@ -255,7 +261,7 @@ class NginxParser: except IOError: logger.error("Could not open file for writing: %s", filename) - def parse_server(self, server): + def parse_server(self, server: UnspacedList) -> Dict[str, Any]: """Parses a list of server directives, accounting for global address sslishness. :param list server: list of directives in a server block @@ -266,7 +272,7 @@ class NginxParser: _apply_global_addr_ssl(addr_to_ssl, parsed_server) return parsed_server - def has_ssl_on_directive(self, vhost): + def has_ssl_on_directive(self, vhost: obj.VirtualHost) -> bool: """Does vhost have ssl on for all ports? :param :class:`~certbot_nginx._internal.obj.VirtualHost` vhost: The vhost in question @@ -284,7 +290,8 @@ class NginxParser: return False - def add_server_directives(self, vhost, directives, insert_at_top=False): + def add_server_directives(self, vhost: obj.VirtualHost, directives: List[Any], + insert_at_top: bool = False) -> None: """Add directives to the server block identified by vhost. This method modifies vhost to be fully consistent with the new directives. @@ -305,7 +312,8 @@ class NginxParser: self._modify_server_directives(vhost, functools.partial(_add_directives, directives, insert_at_top)) - def update_or_add_server_directives(self, vhost, directives, insert_at_top=False): + def update_or_add_server_directives(self, vhost: obj.VirtualHost, directives: List[Any], + insert_at_top: bool = False) -> None: """Add or replace directives in the server block identified by vhost. This method modifies vhost to be fully consistent with the new directives. @@ -327,7 +335,8 @@ class NginxParser: self._modify_server_directives(vhost, functools.partial(_update_or_add_directives, directives, insert_at_top)) - def remove_server_directives(self, vhost, directive_name, match_func=None): + def remove_server_directives(self, vhost: obj.VirtualHost, directive_name: str, + match_func: Optional[Callable[[Any], bool]] = None) -> None: """Remove all directives of type directive_name. :param :class:`~certbot_nginx._internal.obj.VirtualHost` vhost: The vhost @@ -339,7 +348,8 @@ class NginxParser: self._modify_server_directives(vhost, functools.partial(_remove_directives, directive_name, match_func)) - def _update_vhost_based_on_new_directives(self, vhost, directives_list): + def _update_vhost_based_on_new_directives(self, vhost: obj.VirtualHost, + directives_list: UnspacedList) -> None: new_server = self._get_included_directives(directives_list) parsed_server = self.parse_server(new_server) vhost.addrs = parsed_server['addrs'] @@ -347,7 +357,8 @@ class NginxParser: vhost.names = parsed_server['names'] vhost.raw = new_server - def _modify_server_directives(self, vhost, block_func): + def _modify_server_directives(self, vhost: obj.VirtualHost, + block_func: Callable[[List[Any]], None]) -> None: filename = vhost.filep try: result = self.parsed[filename] @@ -364,7 +375,7 @@ class NginxParser: def duplicate_vhost(self, vhost_template: obj.VirtualHost, remove_singleton_listen_params: bool = False, - only_directives: Optional[List] = None) -> obj.VirtualHost: + only_directives: Optional[List[Any]] = None) -> obj.VirtualHost: """Duplicate the vhost in the configuration files. :param :class:`~certbot_nginx._internal.obj.VirtualHost` vhost_template: The vhost @@ -417,7 +428,7 @@ class NginxParser: return new_vhost -def _parse_ssl_options(ssl_options): +def _parse_ssl_options(ssl_options: Optional[str]) -> List[UnspacedList]: if ssl_options is not None: try: with io.open(ssl_options, "r", encoding="utf-8") as _file: @@ -429,9 +440,12 @@ def _parse_ssl_options(ssl_options): "Only UTF-8 encoding is supported.", ssl_options) except pyparsing.ParseBaseException as err: logger.warning("Could not parse file: %s due to %s", ssl_options, err) - return [] + return UnspacedList([]) + -def _do_for_subarray(entry, condition, func, path=None): +def _do_for_subarray(entry: List[Any], condition: Callable[[List[Any]], bool], + func: Callable[[List[Any], List[int]], None], + path: Optional[List[int]] = None) -> None: """Executes a function for a subarray of a nested array if it matches the given condition. @@ -450,7 +464,7 @@ def _do_for_subarray(entry, condition, func, path=None): _do_for_subarray(item, condition, func, path + [index]) -def get_best_match(target_name, names): +def get_best_match(target_name: str, names: Iterable[str]) -> Tuple[Optional[str], Optional[str]]: """Finds the best match for target_name out of names using the Nginx name-matching rules (exact > longest wildcard starting with * > longest wildcard ending with * > regex). @@ -479,29 +493,29 @@ def get_best_match(target_name, names): if exact: # There can be more than one exact match; e.g. eff.org, .eff.org match = min(exact, key=len) - return ('exact', match) + return 'exact', match if wildcard_start: # Return the longest wildcard match = max(wildcard_start, key=len) - return ('wildcard_start', match) + return 'wildcard_start', match if wildcard_end: # Return the longest wildcard match = max(wildcard_end, key=len) - return ('wildcard_end', match) + return 'wildcard_end', match if regex: # Just return the first one for now match = regex[0] - return ('regex', match) + return 'regex', match - return (None, None) + return None, None -def _exact_match(target_name, name): +def _exact_match(target_name: str, name: str) -> bool: target_lower = target_name.lower() return name.lower() in (target_lower, '.' + target_lower) -def _wildcard_match(target_name, name, start): +def _wildcard_match(target_name: str, name: str, start: bool) -> bool: # Degenerate case if name == '*': return True @@ -526,7 +540,7 @@ def _wildcard_match(target_name, name, start): return target_name_lower.endswith('.' + name_lower) -def _regex_match(target_name, name): +def _regex_match(target_name: str, name: str) -> bool: # Must start with a tilde if len(name) < 2 or name[0] != '~': return False @@ -534,13 +548,13 @@ def _regex_match(target_name, name): # After tilde is a perl-compatible regex try: regex = re.compile(name[1:]) - return re.match(regex, target_name) + return bool(re.match(regex, target_name)) except re.error: # pragma: no cover # perl-compatible regexes are sometimes not recognized by python return False -def _is_include_directive(entry): +def _is_include_directive(entry: Any) -> bool: """Checks if an nginx parsed entry is an 'include' directive. :param list entry: the parsed entry @@ -552,7 +566,8 @@ def _is_include_directive(entry): len(entry) == 2 and entry[0] == 'include' and isinstance(entry[1], str)) -def _is_ssl_on_directive(entry): + +def _is_ssl_on_directive(entry: Any) -> bool: """Checks if an nginx parsed entry is an 'ssl on' directive. :param list entry: the parsed entry @@ -564,14 +579,18 @@ def _is_ssl_on_directive(entry): len(entry) == 2 and entry[0] == 'ssl' and entry[1] == 'on') -def _add_directives(directives, insert_at_top, block): + +def _add_directives(directives: List[Any], insert_at_top: bool, + block: UnspacedList) -> None: """Adds directives to a config block.""" for directive in directives: _add_directive(block, directive, insert_at_top) if block and '\n' not in block[-1]: # could be " \n " or ["\n"] ! block.append(nginxparser.UnspacedList('\n')) -def _update_or_add_directives(directives, insert_at_top, block): + +def _update_or_add_directives(directives: List[Any], insert_at_top: bool, + block: UnspacedList) -> None: """Adds or replaces directives in a config block.""" for directive in directives: _update_or_add_directive(block, directive, insert_at_top) @@ -584,7 +603,8 @@ REPEATABLE_DIRECTIVES = {'server_name', 'listen', INCLUDE, 'rewrite', 'add_heade COMMENT = ' managed by Certbot' COMMENT_BLOCK = [' ', '#', COMMENT] -def comment_directive(block, location): + +def comment_directive(block: UnspacedList, location: int) -> None: """Add a ``#managed by Certbot`` comment to the end of the line at location. :param list block: The block containing the directive to be commented @@ -603,40 +623,45 @@ def comment_directive(block, location): if next_entry is not None and "\n" not in next_entry: block.insert(location + 2, '\n') -def _comment_out_directive(block, location, include_location): + +def _comment_out_directive(block: UnspacedList, location: int, include_location: str) -> None: """Comment out the line at location, with a note of explanation.""" comment_message = ' duplicated in {0}'.format(include_location) # add the end comment # create a dumpable object out of block[location] (so it includes the ;) directive = block[location] - new_dir_block = nginxparser.UnspacedList([]) # just a wrapper + new_dir_block = nginxparser.UnspacedList([]) # just a wrapper new_dir_block.append(directive) dumped = nginxparser.dumps(new_dir_block) - commented = dumped + ' #' + comment_message # add the comment directly to the one-line string - new_dir = nginxparser.loads(commented) # reload into UnspacedList + commented = dumped + ' #' + comment_message # add the comment directly to the one-line string + new_dir = nginxparser.loads(commented) # reload into UnspacedList # add the beginning comment insert_location = 0 - if new_dir[0].spaced[0] != new_dir[0][0]: # if there's whitespace at the beginning + if new_dir[0].spaced[0] != new_dir[0][0]: # if there's whitespace at the beginning insert_location = 1 - new_dir[0].spaced.insert(insert_location, "# ") # comment out the line - new_dir[0].spaced.append(";") # directly add in the ;, because now dumping won't work properly + new_dir[0].spaced.insert(insert_location, "# ") # comment out the line + new_dir[0].spaced.append(";") # directly add in the ;, because now dumping won't work properly dumped = nginxparser.dumps(new_dir) - new_dir = nginxparser.loads(dumped) # reload into an UnspacedList + new_dir = nginxparser.loads(dumped) # reload into an UnspacedList block[location] = new_dir[0] # set the now-single-line-comment directive back in place -def _find_location(block, directive_name, match_func=None): + +def _find_location(block: UnspacedList, directive_name: str, + match_func: Optional[Callable[[Any], bool]] = None) -> Optional[int]: """Finds the index of the first instance of directive_name in block. If no line exists, use None.""" - return next((index for index, line in enumerate(block) \ - if line and line[0] == directive_name and (match_func is None or match_func(line))), None) + return next((index for index, line in enumerate(block) if ( + line and line[0] == directive_name and (match_func is None or match_func(line)))), None) + -def _is_whitespace_or_comment(directive): +def _is_whitespace_or_comment(directive: Sequence[Any]) -> bool: """Is this directive either a whitespace or comment directive?""" return len(directive) == 0 or directive[0] == '#' -def _add_directive(block, directive, insert_at_top): + +def _add_directive(block: UnspacedList, directive: Sequence[Any], insert_at_top: bool) -> None: if not isinstance(directive, nginxparser.UnspacedList): directive = nginxparser.UnspacedList(directive) if _is_whitespace_or_comment(directive): @@ -653,10 +678,11 @@ def _add_directive(block, directive, insert_at_top): # handle flat include files directive_name = directive[0] - def can_append(loc, dir_name): + + def can_append(loc: Optional[int], dir_name: str) -> bool: """ Can we append this directive to the block? """ return loc is None or (isinstance(dir_name, str) - and dir_name in REPEATABLE_DIRECTIVES) + and dir_name in REPEATABLE_DIRECTIVES) err_fmt = 'tried to insert directive "{0}" but found conflicting "{1}".' @@ -672,10 +698,14 @@ def _add_directive(block, directive, insert_at_top): included_dir_name = included_directive[0] if (not _is_whitespace_or_comment(included_directive) and not can_append(included_dir_loc, included_dir_name)): - if block[included_dir_loc] != included_directive: - raise errors.MisconfigurationError(err_fmt.format(included_directive, - block[included_dir_loc])) - _comment_out_directive(block, included_dir_loc, directive[1]) + + # By construction of can_append(), included_dir_loc cannot be None at that point + resolved_included_dir_loc = cast(int, included_dir_loc) + + if block[resolved_included_dir_loc] != included_directive: + raise errors.MisconfigurationError(err_fmt.format( + included_directive, block[resolved_included_dir_loc])) + _comment_out_directive(block, resolved_included_dir_loc, directive[1]) if can_append(location, directive_name): if insert_at_top: @@ -687,14 +717,22 @@ def _add_directive(block, directive, insert_at_top): else: block.append(directive) comment_directive(block, len(block) - 1) - elif block[location] != directive: - raise errors.MisconfigurationError(err_fmt.format(directive, block[location])) + return + + # By construction of can_append(), location cannot be None at that point + resolved_location = cast(int, location) + + if block[resolved_location] != directive: + raise errors.MisconfigurationError(err_fmt.format(directive, block[resolved_location])) -def _update_directive(block, directive, location): + +def _update_directive(block: UnspacedList, directive: Sequence[Any], location: int) -> None: block[location] = directive comment_directive(block, location) -def _update_or_add_directive(block, directive, insert_at_top): + +def _update_or_add_directive(block: UnspacedList, directive: Sequence[Any], + insert_at_top: bool) -> None: if not isinstance(directive, nginxparser.UnspacedList): directive = nginxparser.UnspacedList(directive) if _is_whitespace_or_comment(directive): @@ -711,10 +749,13 @@ def _update_or_add_directive(block, directive, insert_at_top): _add_directive(block, directive, insert_at_top) -def _is_certbot_comment(directive): + +def _is_certbot_comment(directive: Sequence[Any]) -> bool: return '#' in directive and COMMENT in directive -def _remove_directives(directive_name, match_func, block): + +def _remove_directives(directive_name: str, match_func: Callable[[Any], bool], + block: UnspacedList) -> None: """Removes directives of name directive_name from a config block if match_func matches. """ while True: @@ -726,7 +767,9 @@ def _remove_directives(directive_name, match_func, block): del block[location + 1] del block[location] -def _apply_global_addr_ssl(addr_to_ssl, parsed_server): + +def _apply_global_addr_ssl(addr_to_ssl: Mapping[Tuple[str, str], bool], + parsed_server: Dict[str, Any]) -> None: """Apply global sslishness information to the parsed server block """ for addr in parsed_server['addrs']: @@ -734,7 +777,8 @@ def _apply_global_addr_ssl(addr_to_ssl, parsed_server): if addr.ssl: parsed_server['ssl'] = True -def _parse_server_raw(server): + +def _parse_server_raw(server: UnspacedList) -> Dict[str, Any]: """Parses a list of server directives. :param list server: list of directives in a server block diff --git a/certbot-nginx/certbot_nginx/_internal/parser_obj.py b/certbot-nginx/certbot_nginx/_internal/parser_obj.py index 33ed822c3..0af38a936 100644 --- a/certbot-nginx/certbot_nginx/_internal/parser_obj.py +++ b/certbot-nginx/certbot_nginx/_internal/parser_obj.py @@ -5,7 +5,14 @@ raw lists of tokens from pyparsing. """ import abc import logging +from typing import Any +from typing import Callable +from typing import Iterator from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type from certbot import errors @@ -23,24 +30,24 @@ class Parsable: __metaclass__ = abc.ABCMeta - def __init__(self, parent=None): - self._data: List[object] = [] + def __init__(self, parent: Optional["Parsable"] = None): + self._data: List[Any] = [] self._tabs = None self.parent = parent @classmethod - def parsing_hooks(cls): + def parsing_hooks(cls) -> Tuple[Type["Block"], Type["Sentence"], Type["Statements"]]: """Returns object types that this class should be able to `parse` recusrively. The order of the objects indicates the order in which the parser should try to parse each subitem. :returns: A list of Parsable classes. :rtype list: """ - return (Block, Sentence, Statements) + return Block, Sentence, Statements @staticmethod @abc.abstractmethod - def should_parse(lists): + def should_parse(lists: Any) -> bool: """ Returns whether the contents of `lists` can be parsed into this object. :returns: Whether `lists` can be parsed as this object. @@ -49,7 +56,7 @@ class Parsable: raise NotImplementedError() @abc.abstractmethod - def parse(self, raw_list, add_spaces=False): + def parse(self, raw_list: List[Any], add_spaces: bool = False) -> None: """ Loads information into this object from underlying raw_list structure. Each Parsable object might make different assumptions about the structure of raw_list. @@ -64,7 +71,8 @@ class Parsable: raise NotImplementedError() @abc.abstractmethod - def iterate(self, expanded=False, match=None): + def iterate(self, expanded: bool = False, + match: Optional[Callable[["Parsable"], bool]] = None) -> Iterator[Any]: """ Iterates across this object. If this object is a leaf object, only yields itself. If it contains references other parsing objects, and `expanded` is set, this function should first yield itself, then recursively iterate across all of them. @@ -77,7 +85,7 @@ class Parsable: raise NotImplementedError() @abc.abstractmethod - def get_tabs(self): + def get_tabs(self) -> str: """ Guess at the tabbing style of this parsed object, based on whitespace. If this object is a leaf, it deducts the tabbing based on its own contents. @@ -90,7 +98,7 @@ class Parsable: raise NotImplementedError() @abc.abstractmethod - def set_tabs(self, tabs=" "): + def set_tabs(self, tabs: str = " ") -> None: """This tries to set and alter the tabbing of the current object to a desired whitespace string. Primarily meant for objects that were constructed, so they can conform to surrounding whitespace. @@ -99,7 +107,7 @@ class Parsable: """ raise NotImplementedError() - def dump(self, include_spaces=False): + def dump(self, include_spaces: bool = False) -> List[Any]: """ Dumps back to pyparsing-like list tree. The opposite of `parse`. Note: if this object has not been modified, `dump` with `include_spaces=True` @@ -121,17 +129,17 @@ class Statements(Parsable): an extra `_trailing_whitespace` string to keep track of the whitespace that does not precede any more statements. """ - def __init__(self, parent=None): + def __init__(self, parent: Optional[Parsable] = None): super().__init__(parent) self._trailing_whitespace = None # ======== Begin overridden functions @staticmethod - def should_parse(lists): + def should_parse(lists: Any) -> bool: return isinstance(lists, list) - def set_tabs(self, tabs=" "): + def set_tabs(self, tabs: str = " ") -> None: """ Sets the tabbing for this set of statements. Does this by calling `set_tabs` on each of the child statements. @@ -144,7 +152,7 @@ class Statements(Parsable): if self.parent is not None: self._trailing_whitespace = "\n" + self.parent.get_tabs() - def parse(self, raw_list, add_spaces=False): + def parse(self, raw_list: List[Any], add_spaces: bool = False) -> None: """ Parses a list of statements. Expects all elements in `raw_list` to be parseable by `type(self).parsing_hooks`, with an optional whitespace string at the last index of `raw_list`. @@ -157,14 +165,14 @@ class Statements(Parsable): raw_list = raw_list[:-1] self._data = [parse_raw(elem, self, add_spaces) for elem in raw_list] - def get_tabs(self): + def get_tabs(self) -> str: """ Takes a guess at the tabbing of all contained Statements by retrieving the tabbing of the first Statement.""" if self._data: return self._data[0].get_tabs() return "" - def dump(self, include_spaces=False): + def dump(self, include_spaces: bool = False) -> List[Any]: """ Dumps this object by first dumping each statement, then appending its trailing whitespace (if `include_spaces` is set) """ data = super().dump(include_spaces) @@ -172,7 +180,8 @@ class Statements(Parsable): return data + [self._trailing_whitespace] return data - def iterate(self, expanded=False, match=None): + def iterate(self, expanded: bool = False, + match: Optional[Callable[["Parsable"], bool]] = None) -> Iterator[Any]: """ Combines each statement's iterator. """ for elem in self._data: for sub_elem in elem.iterate(expanded, match): @@ -181,7 +190,7 @@ class Statements(Parsable): # ======== End overridden functions -def _space_list(list_): +def _space_list(list_: Sequence[Any]) -> List[str]: """ Inserts whitespace between adjacent non-whitespace tokens. """ spaced_statement: List[str] = [] for i in reversed(range(len(list_))): @@ -197,7 +206,7 @@ class Sentence(Parsable): # ======== Begin overridden functions @staticmethod - def should_parse(lists): + def should_parse(lists: Any) -> bool: """ Returns True if `lists` can be parseable as a `Sentence`-- that is, every element is a string type. @@ -205,38 +214,39 @@ class Sentence(Parsable): :returns: whether this lists is parseable by `Sentence`. """ - return isinstance(lists, list) and len(lists) > 0 and \ - all(isinstance(elem, str) for elem in lists) + return (isinstance(lists, list) and len(lists) > 0 and + all(isinstance(elem, str) for elem in lists)) - def parse(self, raw_list, add_spaces=False): + def parse(self, raw_list: List[Any], add_spaces: bool = False) -> None: """ Parses a list of string types into this object. If add_spaces is set, adds whitespace tokens between adjacent non-whitespace tokens.""" if add_spaces: raw_list = _space_list(raw_list) - if not isinstance(raw_list, list) or \ - any(not isinstance(elem, str) for elem in raw_list): + if (not isinstance(raw_list, list) + or any(not isinstance(elem, str) for elem in raw_list)): raise errors.MisconfigurationError("Sentence parsing expects a list of string types.") self._data = raw_list - def iterate(self, expanded=False, match=None): + def iterate(self, expanded: bool = False, + match: Optional[Callable[[Parsable], bool]] = None) -> Iterator[Any]: """ Simply yields itself. """ if match is None or match(self): yield self - def set_tabs(self, tabs=" "): + def set_tabs(self, tabs: str = " ") -> None: """ Sets the tabbing on this sentence. Inserts a newline and `tabs` at the beginning of `self._data`. """ if self._data[0].isspace(): return self._data.insert(0, "\n" + tabs) - def dump(self, include_spaces=False): + def dump(self, include_spaces: bool = False) -> List[Any]: """ Dumps this sentence. If include_spaces is set, includes whitespace tokens.""" if not include_spaces: return self.words return self._data - def get_tabs(self): + def get_tabs(self) -> str: """ Guesses at the tabbing of this sentence. If the first element is whitespace, returns the whitespace after the rightmost newline in the string. """ first = self._data[0] @@ -248,14 +258,14 @@ class Sentence(Parsable): # ======== End overridden functions @property - def words(self): + def words(self) -> List[str]: """ Iterates over words, but without spaces. Like Unspaced List. """ return [word.strip("\"\'") for word in self._data if not word.isspace()] - def __getitem__(self, index): + def __getitem__(self, index: int) -> str: return self.words[index] - def __contains__(self, word): + def __contains__(self, word: str) -> bool: return word in self.words @@ -270,13 +280,13 @@ class Block(Parsable): names = ["block", " ", "name", " "] contents = [["\n ", "content", " ", "1"], ["\n ", "content", " ", "2"], "\n"] """ - def __init__(self, parent=None): + def __init__(self, parent: Optional[Parsable] = None) -> None: super().__init__(parent) - self.names: Sentence = None - self.contents: Block = None + self.names: Optional[Sentence] = None + self.contents: Optional[Block] = None @staticmethod - def should_parse(lists): + def should_parse(lists: Any) -> bool: """ Returns True if `lists` can be parseable as a `Block`-- that is, it's got a length of 2, the first element is a `Sentence` and the second can be a `Statements`. @@ -287,13 +297,14 @@ class Block(Parsable): return isinstance(lists, list) and len(lists) == 2 and \ Sentence.should_parse(lists[0]) and isinstance(lists[1], list) - def set_tabs(self, tabs=" "): + def set_tabs(self, tabs: str = " ") -> None: """ Sets tabs by setting equivalent tabbing on names, then adding tabbing to contents.""" self.names.set_tabs(tabs) self.contents.set_tabs(tabs + " ") - def iterate(self, expanded=False, match=None): + def iterate(self, expanded: bool = False, + match: Optional[Callable[[Parsable], bool]] = None) -> Iterator[Any]: """ Iterator over self, and if expanded is set, over its contents. """ if match is None or match(self): yield self @@ -301,7 +312,7 @@ class Block(Parsable): for elem in self.contents.iterate(expanded, match): yield elem - def parse(self, raw_list, add_spaces=False): + def parse(self, raw_list: List[Any], add_spaces: bool = False) -> None: """ Parses a list that resembles a block. The assumptions that this routine makes are: @@ -323,11 +334,12 @@ class Block(Parsable): self.contents.parse(raw_list[1], add_spaces) self._data = [self.names, self.contents] - def get_tabs(self): + def get_tabs(self) -> str: """ Guesses tabbing by retrieving tabbing guess of self.names. """ return self.names.get_tabs() -def _is_comment(parsed_obj): + +def _is_comment(parsed_obj: Parsable) -> bool: """ Checks whether parsed_obj is a comment. :param .Parsable parsed_obj: @@ -339,7 +351,8 @@ def _is_comment(parsed_obj): return False return parsed_obj.words[0] == "#" -def _is_certbot_comment(parsed_obj): + +def _is_certbot_comment(parsed_obj: Parsable) -> bool: """ Checks whether parsed_obj is a "managed by Certbot" comment. :param .Parsable parsed_obj: @@ -356,7 +369,8 @@ def _is_certbot_comment(parsed_obj): return False return True -def _certbot_comment(parent, preceding_spaces=4): + +def _certbot_comment(parent: Parsable, preceding_spaces: int = 4) -> Sentence: """ A "Managed by Certbot" comment. :param int preceding_spaces: Number of spaces between the end of the previous statement and the comment. @@ -367,7 +381,8 @@ def _certbot_comment(parent, preceding_spaces=4): result.parse([" " * preceding_spaces] + COMMENT_BLOCK) return result -def _choose_parser(parent, list_): + +def _choose_parser(parent: Parsable, list_: Any) -> Parsable: """ Choose a parser from type(parent).parsing_hooks, depending on whichever hook returns True first. """ hooks = Parsable.parsing_hooks() @@ -379,7 +394,8 @@ def _choose_parser(parent, list_): raise errors.MisconfigurationError( "None of the parsing hooks succeeded, so we don't know how to parse this set of lists.") -def parse_raw(lists_, parent=None, add_spaces=False): + +def parse_raw(lists_: Any, parent: Optional[Parsable] = None, add_spaces: bool = False) -> Parsable: """ Primary parsing factory function. :param list lists_: raw lists from pyparsing to parse. |