Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/certbot/certbot.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdrien Ferrand <adferrand@users.noreply.github.com>2022-01-13 03:36:51 +0300
committerGitHub <noreply@github.com>2022-01-13 03:36:51 +0300
commit16aad35d31a887dab157f9d4f5e0fe9218d06064 (patch)
tree067093f1e0523b9de843afad8aca718a3570a172 /certbot-nginx
parent30b066f08260b73fc26256b5484a180468b9d0a6 (diff)
Fully type certbot-nginx module (#9124)
* Work in progress * Fix type * Work in progress * Work in progress * Work in progress * Work in progress * Work in progress * Oups. * Fix typing in UnspacedList * Fix logic * Finish typing * List certbot-nginx as fully typed in tox * Fix lint * Fix checks * Organize imports * Fix typing for Python 3.6 * Fix checks * Fix lint * Update certbot-nginx/certbot_nginx/_internal/configurator.py Co-authored-by: alexzorin <alex@zor.io> * Update certbot-nginx/certbot_nginx/_internal/configurator.py Co-authored-by: alexzorin <alex@zor.io> * Fix signature of deploy_cert regarding the installer interface * Update certbot-nginx/certbot_nginx/_internal/obj.py Co-authored-by: alexzorin <alex@zor.io> * Fix types * Update certbot-nginx/certbot_nginx/_internal/parser.py Co-authored-by: alexzorin <alex@zor.io> * Precise type * Precise _coerce possible inputs/outputs * Fix type * Update certbot-nginx/certbot_nginx/_internal/http_01.py Co-authored-by: ohemorange <ebportnoy@gmail.com> * Fix type * Remove an undesirable implementation. * Fix type Co-authored-by: alexzorin <alex@zor.io> Co-authored-by: ohemorange <ebportnoy@gmail.com>
Diffstat (limited to 'certbot-nginx')
-rw-r--r--certbot-nginx/certbot_nginx/_internal/configurator.py190
-rw-r--r--certbot-nginx/certbot_nginx/_internal/constants.py6
-rw-r--r--certbot-nginx/certbot_nginx/_internal/display_ops.py9
-rw-r--r--certbot-nginx/certbot_nginx/_internal/http_01.py48
-rw-r--r--certbot-nginx/certbot_nginx/_internal/nginxparser.py148
-rw-r--r--certbot-nginx/certbot_nginx/_internal/obj.py49
-rw-r--r--certbot-nginx/certbot_nginx/_internal/parser.py184
-rw-r--r--certbot-nginx/certbot_nginx/_internal/parser_obj.py104
8 files changed, 441 insertions, 297 deletions
diff --git a/certbot-nginx/certbot_nginx/_internal/configurator.py b/certbot-nginx/certbot_nginx/_internal/configurator.py
index 46ec7d57e..fb819f194 100644
--- a/certbot-nginx/certbot_nginx/_internal/configurator.py
+++ b/certbot-nginx/certbot_nginx/_internal/configurator.py
@@ -6,30 +6,38 @@ import socket
import subprocess
import tempfile
import time
+from typing import Any
+from typing import Callable
from typing import Dict
+from typing import Iterable
from typing import List
+from typing import Mapping
from typing import Optional
+from typing import Sequence
from typing import Set
from typing import Text
from typing import Tuple
+from typing import Type
+from typing import Union
+from certbot_nginx._internal import constants
+from certbot_nginx._internal import display_ops
+from certbot_nginx._internal import http_01
+from certbot_nginx._internal import nginxparser
+from certbot_nginx._internal import obj
+from certbot_nginx._internal import parser
import OpenSSL
import pkg_resources
from acme import challenges
from acme import crypto_util as acme_crypto_util
+from certbot import achallenges
from certbot import crypto_util
from certbot import errors
from certbot import util
-from certbot.display import util as display_util
from certbot.compat import os
+from certbot.display import util as display_util
from certbot.plugins import common
-from certbot_nginx._internal import constants
-from certbot_nginx._internal import display_ops
-from certbot_nginx._internal import http_01
-from certbot_nginx._internal import nginxparser
-from certbot_nginx._internal import obj
-from certbot_nginx._internal import parser
NAME_RANK = 0
START_WILDCARD_RANK = 1
@@ -70,7 +78,7 @@ class NginxConfigurator(common.Configurator):
SSL_DIRECTIVES = ['ssl_certificate', 'ssl_certificate_key', 'ssl_dhparam']
@classmethod
- def add_parser_arguments(cls, add):
+ def add_parser_arguments(cls, add: Callable[..., None]) -> None:
default_server_root = _determine_default_server_root()
add("server-root", default=constants.CLI_DEFAULTS["server_root"],
help="Nginx server root directory. (default: %s)" % default_server_root)
@@ -82,11 +90,11 @@ class NginxConfigurator(common.Configurator):
"to apply when reloading.")
@property
- def nginx_conf(self):
+ def nginx_conf(self) -> str:
"""Nginx config file path."""
return os.path.join(self.conf("server_root"), "nginx.conf")
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize an Nginx Configurator.
:param tup version: version of Nginx as a tuple (1, 4, 7)
@@ -125,7 +133,7 @@ class NginxConfigurator(common.Configurator):
self.parser: parser.NginxParser
@property
- def mod_ssl_conf_src(self):
+ def mod_ssl_conf_src(self) -> str:
"""Full absolute path to SSL configuration file source."""
# Why all this complexity? Well, we want to support Mozilla's intermediate
@@ -159,22 +167,23 @@ class NginxConfigurator(common.Configurator):
"certbot_nginx", os.path.join("_internal", "tls_configs", config_filename))
@property
- def mod_ssl_conf(self):
+ def mod_ssl_conf(self) -> str:
"""Full absolute path to SSL configuration file."""
return os.path.join(self.config.config_dir, constants.MOD_SSL_CONF_DEST)
@property
- def updated_mod_ssl_conf_digest(self):
+ def updated_mod_ssl_conf_digest(self) -> str:
"""Full absolute path to digest of updated SSL configuration file."""
return os.path.join(self.config.config_dir, constants.UPDATED_MOD_SSL_CONF_DIGEST)
- def install_ssl_options_conf(self, options_ssl, options_ssl_digest):
+ def install_ssl_options_conf(self, options_ssl: str, options_ssl_digest: str) -> None:
"""Copy Certbot's SSL options file into the system's config dir if required."""
- return common.install_version_controlled_file(options_ssl, options_ssl_digest,
+ common.install_version_controlled_file(
+ options_ssl, options_ssl_digest,
self.mod_ssl_conf_src, constants.ALL_SSL_OPTIONS_HASHES)
# This is called in determine_authenticator and determine_installer
- def prepare(self):
+ def prepare(self) -> None:
"""Prepare the authenticator/installer.
:raises .errors.NoInstallationError: If Nginx ctl cannot be found
@@ -210,8 +219,8 @@ class NginxConfigurator(common.Configurator):
raise errors.PluginError('Unable to lock {0}'.format(self.conf('server-root')))
# Entry point in main.py for installing cert
- def deploy_cert(self, domain, cert_path, key_path,
- chain_path=None, fullchain_path=None):
+ def deploy_cert(self, domain: str, cert_path: str, key_path: str, chain_path: str,
+ fullchain_path: str) -> None:
"""Deploys certificate to specified virtual host.
.. note:: Aborts if the vhost is missing ssl_certificate or
@@ -234,7 +243,8 @@ class NginxConfigurator(common.Configurator):
display_util.notify("Successfully deployed certificate for {} to {}"
.format(domain, vhost.filep))
- def _deploy_cert(self, vhost, cert_path, key_path, chain_path, fullchain_path): # pylint: disable=unused-argument
+ def _deploy_cert(self, vhost: obj.VirtualHost, _cert_path: str, key_path: str,
+ _chain_path: str, fullchain_path: str) -> None:
"""
Helper function for deploy_cert() that handles the actual deployment
this exists because we might want to do multiple deployments per
@@ -244,8 +254,7 @@ class NginxConfigurator(common.Configurator):
cert_directives = [['\n ', 'ssl_certificate', ' ', fullchain_path],
['\n ', 'ssl_certificate_key', ' ', key_path]]
- self.parser.update_or_add_server_directives(vhost,
- cert_directives)
+ self.parser.update_or_add_server_directives(vhost, cert_directives)
logger.info("Deploying Certificate to VirtualHost %s", vhost.filep)
self.save_notes += ("Changed vhost at %s with addresses of %s\n" %
@@ -254,7 +263,8 @@ class NginxConfigurator(common.Configurator):
self.save_notes += "\tssl_certificate %s\n" % fullchain_path
self.save_notes += "\tssl_certificate_key %s\n" % key_path
- def _choose_vhosts_wildcard(self, domain, prefer_ssl, no_ssl_filter_port=None):
+ def _choose_vhosts_wildcard(self, domain: str, prefer_ssl: bool,
+ no_ssl_filter_port: Optional[str] = None) -> List[obj.VirtualHost]:
"""Prompts user to choose vhosts to install a wildcard certificate for"""
if prefer_ssl:
vhosts_cache = self._wildcard_vhosts
@@ -303,12 +313,13 @@ class NginxConfigurator(common.Configurator):
#######################
# Vhost parsing methods
#######################
- def _choose_vhost_single(self, target_name):
+ def _choose_vhost_single(self, target_name: str) -> List[obj.VirtualHost]:
matches = self._get_ranked_matches(target_name)
vhosts = [x for x in [self._select_best_name_match(matches)] if x is not None]
return vhosts
- def choose_vhosts(self, target_name, create_if_no_match=False):
+ def choose_vhosts(self, target_name: str,
+ create_if_no_match: bool = False) -> List[obj.VirtualHost]:
"""Chooses a virtual host based on the given domain name.
.. note:: This makes the vhost SSL-enabled if it isn't already. Follows
@@ -352,7 +363,7 @@ class NginxConfigurator(common.Configurator):
return vhosts
- def ipv6_info(self, port):
+ def ipv6_info(self, port: str) -> Tuple[bool, bool]:
"""Returns tuple of booleans (ipv6_active, ipv6only_present)
ipv6_active is true if any server block listens ipv6 address in any port
@@ -365,9 +376,6 @@ class NginxConfigurator(common.Configurator):
configuration, and existence of ipv6only directive for specified port
:rtype: tuple of type (bool, bool)
"""
- # port should be a string, but it's easy to mess up, so let's
- # make sure it is one
- port = str(port)
vhosts = self.parser.get_vhosts()
ipv6_active = False
ipv6only_present = False
@@ -377,10 +385,10 @@ class NginxConfigurator(common.Configurator):
ipv6_active = True
if addr.ipv6only and addr.get_port() == port:
ipv6only_present = True
- return (ipv6_active, ipv6only_present)
+ return ipv6_active, ipv6only_present
- def _vhost_from_duplicated_default(self, domain: str, allow_port_mismatch: bool, port: str
- ) -> obj.VirtualHost:
+ def _vhost_from_duplicated_default(self, domain: str, allow_port_mismatch: bool,
+ port: str) -> obj.VirtualHost:
"""if allow_port_mismatch is False, only server blocks with matching ports will be
used as a default server block template.
"""
@@ -395,7 +403,7 @@ class NginxConfigurator(common.Configurator):
self._add_server_name_to_vhost(self.new_vhost, domain)
return self.new_vhost
- def _add_server_name_to_vhost(self, vhost, domain):
+ def _add_server_name_to_vhost(self, vhost: obj.VirtualHost, domain: str) -> None:
vhost.names.add(domain)
name_block = [['\n ', 'server_name']]
for name in vhost.names:
@@ -403,7 +411,8 @@ class NginxConfigurator(common.Configurator):
name_block[0].append(name)
self.parser.update_or_add_server_directives(vhost, name_block)
- def _get_default_vhost(self, domain, allow_port_mismatch, port):
+ def _get_default_vhost(self, domain: str, allow_port_mismatch: bool,
+ port: str) -> obj.VirtualHost:
"""Helper method for _vhost_from_duplicated_default; see argument documentation there"""
vhost_list = self.parser.get_vhosts()
# if one has default_server set, return that one
@@ -424,10 +433,11 @@ class NginxConfigurator(common.Configurator):
# TODO: present a list of vhosts for user to choose from
- raise errors.MisconfigurationError("Could not automatically find a matching server"
- " block for %s. Set the `server_name` directive to use the Nginx installer." % domain)
+ raise errors.MisconfigurationError("Could not automatically find a matching server "
+ f"block for {domain}. Set the `server_name` directive "
+ "to use the Nginx installer.")
- def _get_ranked_matches(self, target_name):
+ def _get_ranked_matches(self, target_name: str) -> List[Dict[str, Any]]:
"""Returns a ranked list of vhosts that match target_name.
The ranking gives preference to SSL vhosts.
@@ -440,7 +450,8 @@ class NginxConfigurator(common.Configurator):
vhost_list = self.parser.get_vhosts()
return self._rank_matches_by_name_and_ssl(vhost_list, target_name)
- def _select_best_name_match(self, matches):
+ def _select_best_name_match(self,
+ matches: Sequence[Mapping[str, Any]]) -> Optional[obj.VirtualHost]:
"""Returns the best name match of a ranked list of vhosts.
:param list matches: list of dicts containing the vhost, the matching name,
@@ -460,7 +471,8 @@ class NginxConfigurator(common.Configurator):
# Exact or regex match
return matches[0]['vhost']
- def _rank_matches_by_name(self, vhost_list, target_name):
+ def _rank_matches_by_name(self, vhost_list: Iterable[obj.VirtualHost],
+ target_name: str) -> List[Dict[str, Any]]:
"""Returns a ranked list of vhosts from vhost_list that match target_name.
This method should always be followed by a call to _select_best_name_match.
@@ -497,7 +509,8 @@ class NginxConfigurator(common.Configurator):
'rank': REGEX_RANK})
return sorted(matches, key=lambda x: x['rank'])
- def _rank_matches_by_name_and_ssl(self, vhost_list, target_name):
+ def _rank_matches_by_name_and_ssl(self, vhost_list: Iterable[obj.VirtualHost],
+ target_name: str) -> List[Dict[str, Any]]:
"""Returns a ranked list of vhosts from vhost_list that match target_name.
The ranking gives preference to SSLishness before name match level.
@@ -610,7 +623,7 @@ class NginxConfigurator(common.Configurator):
def _vhost_listening_on_port_no_ssl(self, vhost: obj.VirtualHost, port: str) -> bool:
return self._vhost_listening(vhost, port, False)
- def _get_redirect_ranked_matches(self, target_name, port):
+ def _get_redirect_ranked_matches(self, target_name: str, port: str) -> List[Dict[str, Any]]:
"""Gets a ranked list of plaintextish port-listening vhosts matching target_name
Filter all hosts for those listening on port without using ssl.
@@ -625,14 +638,14 @@ class NginxConfigurator(common.Configurator):
"""
all_vhosts = self.parser.get_vhosts()
- def _vhost_matches(vhost, port):
+ def _vhost_matches(vhost: obj.VirtualHost, port: str) -> bool:
return self._vhost_listening_on_port_no_ssl(vhost, port)
matching_vhosts = [vhost for vhost in all_vhosts if _vhost_matches(vhost, port)]
return self._rank_matches_by_name(matching_vhosts, target_name)
- def get_all_names(self):
+ def get_all_names(self) -> Set[str]:
"""Returns all names found in the Nginx Configuration.
:returns: All ServerNames, ServerAliases, and reverse DNS entries for
@@ -670,7 +683,7 @@ class NginxConfigurator(common.Configurator):
return util.get_filtered_names(all_names)
- def _get_snakeoil_paths(self):
+ def _get_snakeoil_paths(self) -> Tuple[str, str]:
"""Generate invalid certs that let us create ssl directives for Nginx"""
# TODO: generate only once
tmp_dir = os.path.join(self.config.work_dir, "snakeoil")
@@ -688,7 +701,7 @@ class NginxConfigurator(common.Configurator):
cert_file.write(cert_pem)
return cert_path, le_key.file
- def _make_server_ssl(self, vhost):
+ def _make_server_ssl(self, vhost: obj.VirtualHost) -> None:
"""Make a server SSL.
Make a server SSL by adding new listen and SSL directives.
@@ -698,7 +711,7 @@ class NginxConfigurator(common.Configurator):
"""
https_port = self.config.https_port
- ipv6info = self.ipv6_info(https_port)
+ ipv6info = self.ipv6_info(str(https_port))
ipv6_block = ['']
ipv4_block = ['']
@@ -745,11 +758,12 @@ class NginxConfigurator(common.Configurator):
##################################
# enhancement methods (Installer)
##################################
- def supported_enhancements(self):
+ def supported_enhancements(self) -> List[str]:
"""Returns currently supported enhancements."""
return ['redirect', 'ensure-http-header', 'staple-ocsp']
- def enhance(self, domain, enhancement, options=None):
+ def enhance(self, domain: str, enhancement: str,
+ options: Optional[Union[str, List[str]]] = None) -> None:
"""Enhance configuration.
:param str domain: domain to enhance
@@ -761,16 +775,16 @@ class NginxConfigurator(common.Configurator):
"""
try:
- return self._enhance_func[enhancement](domain, options)
+ self._enhance_func[enhancement](domain, options)
except (KeyError, ValueError):
raise errors.PluginError(
"Unsupported enhancement: {0}".format(enhancement))
- def _has_certbot_redirect(self, vhost, domain):
+ def _has_certbot_redirect(self, vhost: obj.VirtualHost, domain: str) -> bool:
test_redirect_block = _test_block_from_block(_redirect_block_for_domain(domain))
return vhost.contains_list(test_redirect_block)
- def _set_http_header(self, domain, header_substring):
+ def _set_http_header(self, domain: str, header_substring: Union[str, List[str], None]) -> None:
"""Enables header identified by header_substring on domain.
If the vhost is listening plaintextishly, separates out the relevant
@@ -784,7 +798,10 @@ class NginxConfigurator(common.Configurator):
:raises .errors.PluginError: If no viable HTTPS host can be created or
set with header header_substring.
"""
- if not header_substring in constants.HEADER_ARGS:
+ if not isinstance(header_substring, str):
+ raise errors.NotSupportedError("Invalid header_substring type "
+ f"{type(header_substring)}, expected a str.")
+ if header_substring not in constants.HEADER_ARGS:
raise errors.NotSupportedError(
f"{header_substring} is not supported by the nginx plugin.")
@@ -808,7 +825,7 @@ class NginxConfigurator(common.Configurator):
['\n']]
self.parser.add_server_directives(vhost, header_directives)
- def _add_redirect_block(self, vhost, domain):
+ def _add_redirect_block(self, vhost: obj.VirtualHost, domain: str) -> None:
"""Add redirect directive to vhost
"""
redirect_block = _redirect_block_for_domain(domain)
@@ -816,7 +833,8 @@ class NginxConfigurator(common.Configurator):
self.parser.add_server_directives(
vhost, redirect_block, insert_at_top=True)
- def _split_block(self, vhost, only_directives=None):
+ def _split_block(self, vhost: obj.VirtualHost, only_directives: Optional[List[str]] = None
+ ) -> Tuple[obj.VirtualHost, obj.VirtualHost]:
"""Splits this "virtual host" (i.e. this nginx server block) into
separate HTTP and HTTPS blocks.
@@ -829,13 +847,13 @@ class NginxConfigurator(common.Configurator):
"""
http_vhost = self.parser.duplicate_vhost(vhost, only_directives=only_directives)
- def _ssl_match_func(directive):
+ def _ssl_match_func(directive: str) -> bool:
return 'ssl' in directive
- def _ssl_config_match_func(directive):
+ def _ssl_config_match_func(directive: str) -> bool:
return self.mod_ssl_conf in directive
- def _no_ssl_match_func(directive):
+ def _no_ssl_match_func(directive: str) -> bool:
return 'ssl' not in directive
# remove all ssl addresses and related directives from the new block
@@ -849,7 +867,8 @@ class NginxConfigurator(common.Configurator):
self.parser.remove_server_directives(vhost, 'listen', match_func=_no_ssl_match_func)
return http_vhost, vhost
- def _enable_redirect(self, domain, unused_options):
+ def _enable_redirect(self, domain: str,
+ unused_options: Optional[Union[str, List[str]]]) -> None:
"""Redirect all equivalent HTTP traffic to ssl_vhost.
If the vhost is listening plaintextishly, separate out the
@@ -876,7 +895,7 @@ class NginxConfigurator(common.Configurator):
for vhost in vhosts:
self._enable_redirect_single(domain, vhost)
- def _enable_redirect_single(self, domain, vhost):
+ def _enable_redirect_single(self, domain: str, vhost: obj.VirtualHost) -> None:
"""Redirect all equivalent HTTP traffic to ssl_vhost.
If the vhost is listening plaintextishly, separate out the
@@ -905,7 +924,8 @@ class NginxConfigurator(common.Configurator):
logger.info("Redirecting all traffic on port %s to ssl in %s",
self.DEFAULT_LISTEN_PORT, vhost.filep)
- def _enable_ocsp_stapling(self, domain, chain_path):
+ def _enable_ocsp_stapling(self, domain: str,
+ chain_path: Optional[Union[str, List[str]]]) -> None:
"""Include OCSP response in TLS handshake
:param str domain: domain to enable OCSP response for
@@ -913,11 +933,15 @@ class NginxConfigurator(common.Configurator):
:type chain_path: `str` or `None`
"""
+ if not isinstance(chain_path, str) and chain_path is not None:
+ raise errors.NotSupportedError(f"Invalid chain_path type {type(chain_path)}, "
+ "expected a str or None.")
vhosts = self.choose_vhosts(domain)
for vhost in vhosts:
self._enable_ocsp_stapling_single(vhost, chain_path)
- def _enable_ocsp_stapling_single(self, vhost, chain_path):
+ def _enable_ocsp_stapling_single(self, vhost: obj.VirtualHost,
+ chain_path: Optional[str]) -> None:
"""Include OCSP response in TLS handshake
:param str vhost: vhost to enable OCSP response for
@@ -957,7 +981,7 @@ class NginxConfigurator(common.Configurator):
######################################
# Nginx server management (Installer)
######################################
- def restart(self):
+ def restart(self) -> None:
"""Restarts nginx server.
:raises .errors.MisconfigurationError: If either the reload fails.
@@ -965,7 +989,7 @@ class NginxConfigurator(common.Configurator):
"""
nginx_restart(self.conf('ctl'), self.nginx_conf, self.conf('sleep-seconds'))
- def config_test(self):
+ def config_test(self) -> None:
"""Check the configuration of Nginx for errors.
:raises .errors.MisconfigurationError: If config_test fails
@@ -976,7 +1000,7 @@ class NginxConfigurator(common.Configurator):
except errors.SubprocessError as err:
raise errors.MisconfigurationError(str(err))
- def _nginx_version(self):
+ def _nginx_version(self) -> str:
"""Return results of nginx -V
:returns: version text
@@ -1000,7 +1024,7 @@ class NginxConfigurator(common.Configurator):
"Unable to run %s -V" % self.conf('ctl'))
return text
- def get_version(self):
+ def get_version(self) -> Tuple[int, ...]:
"""Return version of Nginx Server.
Version is returned as tuple. (ie. 2.4.7 = (2, 4, 7))
@@ -1045,7 +1069,7 @@ class NginxConfigurator(common.Configurator):
return nginx_version
- def _get_openssl_version(self):
+ def _get_openssl_version(self) -> str:
"""Return version of OpenSSL linked to Nginx.
Version is returned as string. If no version can be found, empty string is returned.
@@ -1067,7 +1091,7 @@ class NginxConfigurator(common.Configurator):
return ""
return matches[0]
- def more_info(self):
+ def more_info(self) -> str:
"""Human-readable string to help understand the module"""
return (
"Configures Nginx to authenticate and install HTTPS.{0}"
@@ -1077,7 +1101,8 @@ class NginxConfigurator(common.Configurator):
version=".".join(str(i) for i in self.version))
)
- def auth_hint(self, failed_achalls): # pragma: no cover
+ def auth_hint(self, # pragma: no cover
+ failed_achalls: Iterable[achallenges.AnnotatedChallenge]) -> str:
return (
"The Certificate Authority failed to verify the temporary nginx configuration changes "
"made by Certbot. Ensure the listed domains point to this nginx server and that it is "
@@ -1087,7 +1112,7 @@ class NginxConfigurator(common.Configurator):
###################################################
# Wrapper functions for Reverter class (Installer)
###################################################
- def save(self, title=None, temporary=False):
+ def save(self, title: str = None, temporary: bool = False) -> None:
"""Saves all changes to the configuration files.
:param str title: The title of the save. If a title is given, the
@@ -1111,7 +1136,7 @@ class NginxConfigurator(common.Configurator):
if title and not temporary:
self.finalize_checkpoint(title)
- def recovery_routine(self):
+ def recovery_routine(self) -> None:
"""Revert all previously modified files.
Reverts all modified files that have not been saved as a checkpoint
@@ -1123,7 +1148,7 @@ class NginxConfigurator(common.Configurator):
self.new_vhost = None
self.parser.load()
- def revert_challenge_config(self):
+ def revert_challenge_config(self) -> None:
"""Used to cleanup challenge configurations.
:raises .errors.PluginError: If unable to revert the challenge config.
@@ -1133,7 +1158,7 @@ class NginxConfigurator(common.Configurator):
self.new_vhost = None
self.parser.load()
- def rollback_checkpoints(self, rollback=1):
+ def rollback_checkpoints(self, rollback: int = 1) -> None:
"""Rollback saved checkpoints.
:param int rollback: Number of checkpoints to revert
@@ -1149,12 +1174,13 @@ class NginxConfigurator(common.Configurator):
###########################################################################
# Challenges Section for Authenticator
###########################################################################
- def get_chall_pref(self, unused_domain):
+ def get_chall_pref(self, unused_domain: str) -> List[Type[challenges.Challenge]]:
"""Return list of challenge preferences."""
return [challenges.HTTP01]
# Entry point in main.py for performing challenges
- def perform(self, achalls):
+ def perform(self, achalls: List[achallenges.AnnotatedChallenge]
+ ) -> List[challenges.HTTP01Response]:
"""Perform the configuration related challenge.
This function currently assumes all challenges will be fulfilled.
@@ -1163,7 +1189,7 @@ class NginxConfigurator(common.Configurator):
"""
self._chall_out += len(achalls)
- responses = [None] * len(achalls)
+ responses: List[Optional[challenges.HTTP01Response]] = [None] * len(achalls)
http_doer = http_01.NginxHttp01(self)
for i, achall in enumerate(achalls):
@@ -1183,10 +1209,10 @@ class NginxConfigurator(common.Configurator):
for i, resp in enumerate(http_response):
responses[http_doer.indices[i]] = resp
- return responses
+ return [response for response in responses if response]
# called after challenges are performed
- def cleanup(self, achalls):
+ def cleanup(self, achalls: List[achallenges.AnnotatedChallenge]) -> None:
"""Revert all challenges."""
self._chall_out -= len(achalls)
@@ -1196,13 +1222,13 @@ class NginxConfigurator(common.Configurator):
self.restart()
-def _test_block_from_block(block):
+def _test_block_from_block(block: List[Any]) -> List[Any]:
test_block = nginxparser.UnspacedList(block)
parser.comment_directive(test_block, 0)
return test_block[:-1]
-def _redirect_block_for_domain(domain):
+def _redirect_block_for_domain(domain: str) -> List[Any]:
updated_domain = domain
match_symbol = '='
if util.is_wildcard_domain(domain):
@@ -1218,7 +1244,7 @@ def _redirect_block_for_domain(domain):
return redirect_block
-def nginx_restart(nginx_ctl, nginx_conf, sleep_duration):
+def nginx_restart(nginx_ctl: str, nginx_conf: str, sleep_duration: int) -> None:
"""Restarts the Nginx Server.
.. todo:: Nginx restart is fatal if the configuration references
@@ -1263,10 +1289,10 @@ def nginx_restart(nginx_ctl, nginx_conf, sleep_duration):
time.sleep(sleep_duration)
-def _determine_default_server_root():
+def _determine_default_server_root() -> str:
if os.environ.get("CERTBOT_DOCS") == "1":
- default_server_root = "%s or %s" % (constants.LINUX_SERVER_ROOT,
- constants.FREEBSD_DARWIN_SERVER_ROOT)
+ default_server_root = (f"{constants.LINUX_SERVER_ROOT} "
+ f"or {constants.FREEBSD_DARWIN_SERVER_ROOT}")
else:
default_server_root = constants.CLI_DEFAULTS["server_root"]
return default_server_root
diff --git a/certbot-nginx/certbot_nginx/_internal/constants.py b/certbot-nginx/certbot_nginx/_internal/constants.py
index a0946b0fd..aa242903e 100644
--- a/certbot-nginx/certbot_nginx/_internal/constants.py
+++ b/certbot-nginx/certbot_nginx/_internal/constants.py
@@ -52,7 +52,8 @@ ALL_SSL_OPTIONS_HASHES = [
]
"""SHA256 hashes of the contents of all versions of MOD_SSL_CONF_SRC"""
-def os_constant(key):
+
+def os_constant(key: str) -> Any:
# XXX TODO: In the future, this could return different constants
# based on what OS we are running under. To see an
# approach to how to handle different OSes, see the
@@ -61,11 +62,12 @@ def os_constant(key):
"""
Get a constant value for operating system
- :param key: name of cli constant
+ :param str key: name of cli constant
:return: value of constant for active os
"""
return CLI_DEFAULTS[key]
+
HSTS_ARGS = ['\"max-age=31536000\"', ' ', 'always']
HEADER_ARGS = {'Strict-Transport-Security': HSTS_ARGS}
diff --git a/certbot-nginx/certbot_nginx/_internal/display_ops.py b/certbot-nginx/certbot_nginx/_internal/display_ops.py
index 95163a81f..89483c94a 100644
--- a/certbot-nginx/certbot_nginx/_internal/display_ops.py
+++ b/certbot-nginx/certbot_nginx/_internal/display_ops.py
@@ -1,12 +1,17 @@
"""Contains UI methods for Nginx operations."""
import logging
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+from certbot_nginx._internal.obj import VirtualHost
from certbot.display import util as display_util
logger = logging.getLogger(__name__)
-def select_vhost_multiple(vhosts):
+def select_vhost_multiple(vhosts: Optional[Iterable[VirtualHost]]) -> List[VirtualHost]:
"""Select multiple Vhosts to install the certificate for
:param vhosts: Available Nginx VirtualHosts
:type vhosts: :class:`list` of type `~obj.Vhost`
@@ -28,7 +33,7 @@ def select_vhost_multiple(vhosts):
return []
-def _reversemap_vhosts(names, vhosts):
+def _reversemap_vhosts(names: Iterable[str], vhosts: Iterable[VirtualHost]) -> List[VirtualHost]:
"""Helper function for select_vhost_multiple for mapping string
representations back to actual vhost objects"""
return_vhosts = []
diff --git a/certbot-nginx/certbot_nginx/_internal/http_01.py b/certbot-nginx/certbot_nginx/_internal/http_01.py
index 95592fea0..6f61bfb6f 100644
--- a/certbot-nginx/certbot_nginx/_internal/http_01.py
+++ b/certbot-nginx/certbot_nginx/_internal/http_01.py
@@ -2,17 +2,20 @@
import io
import logging
+from typing import Any
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
+from certbot_nginx._internal import nginxparser
+from certbot_nginx._internal.obj import Addr
+
from acme import challenges
-from certbot import achallenges
+from acme.challenges import HTTP01Response
from certbot import errors
+from certbot.achallenges import KeyAuthorizationAnnotatedChallenge
from certbot.compat import os
from certbot.plugins import common
-from certbot_nginx._internal import nginxparser
-from certbot_nginx._internal import obj
if TYPE_CHECKING:
from certbot_nginx._internal.configurator import NginxConfigurator
@@ -46,7 +49,7 @@ class NginxHttp01(common.ChallengePerformer):
self.challenge_conf = os.path.join(
configurator.config.config_dir, "le_http_01_cert_challenge.conf")
- def perform(self):
+ def perform(self) -> List[HTTP01Response]:
"""Perform a challenge on Nginx.
:returns: list of :class:`certbot.acme.challenges.HTTP01Response`
@@ -66,7 +69,7 @@ class NginxHttp01(common.ChallengePerformer):
return responses
- def _mod_config(self):
+ def _mod_config(self) -> None:
"""Modifies Nginx config to include server_names_hash_bucket_size directive
and server challenge blocks.
@@ -113,39 +116,40 @@ class NginxHttp01(common.ChallengePerformer):
with io.open(self.challenge_conf, "w", encoding="utf-8") as new_conf:
nginxparser.dump(config, new_conf)
- def _default_listen_addresses(self):
+ def _default_listen_addresses(self) -> List[Addr]:
"""Finds addresses for a challenge block to listen on.
:returns: list of :class:`certbot_nginx._internal.obj.Addr` to apply
:rtype: list
"""
- addresses: List[obj.Addr] = []
+ addresses: List[Optional[Addr]] = []
default_addr = "%s" % self.configurator.config.http01_port
ipv6_addr = "[::]:{0}".format(
self.configurator.config.http01_port)
port = self.configurator.config.http01_port
- ipv6, ipv6only = self.configurator.ipv6_info(port)
+ ipv6, ipv6only = self.configurator.ipv6_info(str(port))
if ipv6:
# If IPv6 is active in Nginx configuration
if not ipv6only:
# If ipv6only=on is not already present in the config
ipv6_addr = ipv6_addr + " ipv6only=on"
- addresses = [obj.Addr.fromstring(default_addr),
- obj.Addr.fromstring(ipv6_addr)]
+ addresses = [Addr.fromstring(default_addr),
+ Addr.fromstring(ipv6_addr)]
logger.debug(("Using default addresses %s and %s for authentication."),
default_addr,
ipv6_addr)
else:
- addresses = [obj.Addr.fromstring(default_addr)]
+ addresses = [Addr.fromstring(default_addr)]
logger.debug("Using default address %s for authentication.",
default_addr)
- return addresses
- def _get_validation_path(self, achall):
+ return [address for address in addresses if address]
+
+ def _get_validation_path(self, achall: KeyAuthorizationAnnotatedChallenge) -> str:
return os.sep + os.path.join(challenges.HTTP01.URI_ROOT_PATH, achall.chall.encode("token"))
- def _make_server_block(self, achall: achallenges.KeyAuthorizationAnnotatedChallenge) -> List:
+ def _make_server_block(self, achall: KeyAuthorizationAnnotatedChallenge) -> List[Any]:
"""Creates a server block for a challenge.
:param achall: Annotated HTTP-01 challenge
@@ -168,7 +172,8 @@ class NginxHttp01(common.ChallengePerformer):
# TODO: do we want to return something else if they otherwise access this block?
return [['server'], block]
- def _location_directive_for_achall(self, achall):
+ def _location_directive_for_achall(self, achall: KeyAuthorizationAnnotatedChallenge
+ ) -> List[Any]:
validation = achall.validation(achall.account_key)
validation_path = self._get_validation_path(achall)
@@ -177,9 +182,8 @@ class NginxHttp01(common.ChallengePerformer):
['return', ' ', '200', ' ', validation]]]
return location_directive
-
- def _make_or_mod_server_block(self, achall: achallenges.KeyAuthorizationAnnotatedChallenge
- ) -> Optional[List]:
+ def _make_or_mod_server_block(self, achall: KeyAuthorizationAnnotatedChallenge
+ ) -> Optional[List[Any]]:
"""Modifies server blocks to respond to a challenge. Returns a new HTTP server block
to add to the configuration if an existing one can't be found.
@@ -192,7 +196,7 @@ class NginxHttp01(common.ChallengePerformer):
"""
http_vhosts, https_vhosts = self.configurator.choose_auth_vhosts(achall.domain)
- new_vhost: Optional[list] = None
+ new_vhost: Optional[List[Any]] = None
if not http_vhosts:
# Couldn't find either a matching name+port server block
# or a port+default_server block, so create a dummy block
@@ -205,8 +209,8 @@ class NginxHttp01(common.ChallengePerformer):
self.configurator.parser.add_server_directives(vhost, location_directive)
rewrite_directive = [['rewrite', ' ', '^(/.well-known/acme-challenge/.*)',
- ' ', '$1', ' ', 'break']]
- self.configurator.parser.add_server_directives(vhost,
- rewrite_directive, insert_at_top=True)
+ ' ', '$1', ' ', 'break']]
+ self.configurator.parser.add_server_directives(
+ vhost, rewrite_directive, insert_at_top=True)
return new_vhost
diff --git a/certbot-nginx/certbot_nginx/_internal/nginxparser.py b/certbot-nginx/certbot_nginx/_internal/nginxparser.py
index 03aa88db4..70a55be3a 100644
--- a/certbot-nginx/certbot_nginx/_internal/nginxparser.py
+++ b/certbot-nginx/certbot_nginx/_internal/nginxparser.py
@@ -2,14 +2,23 @@
# Forked from https://github.com/fatiherikli/nginxparser (MIT Licensed)
import copy
import logging
+import typing
from typing import Any
from typing import IO
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import overload
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from pyparsing import Combine
from pyparsing import Forward
from pyparsing import Group
from pyparsing import Literal
from pyparsing import Optional
+from pyparsing import ParseResults
from pyparsing import QuotedString
from pyparsing import Regex
from pyparsing import restOfLine
@@ -17,6 +26,9 @@ from pyparsing import stringEnd
from pyparsing import White
from pyparsing import ZeroOrMore
+if TYPE_CHECKING:
+ from typing_extensions import SupportsIndex # typing.SupportsIndex not supported on Python 3.6
+
logger = logging.getLogger(__name__)
@@ -59,23 +71,24 @@ class RawNginxParser:
script = ZeroOrMore(contents) + space + stringEnd
script.parseWithTabs().leaveWhitespace()
- def __init__(self, source):
+ def __init__(self, source: str) -> None:
self.source = source
- def parse(self):
+ def parse(self) -> ParseResults:
"""Returns the parsed tree."""
return self.script.parseString(self.source)
- def as_list(self):
+ def as_list(self) -> List[Any]:
"""Returns the parsed tree as a list."""
return self.parse().asList()
+
class RawNginxDumper:
"""A class that dumps nginx configuration from the provided tree."""
- def __init__(self, blocks):
+ def __init__(self, blocks: List[Any]) -> None:
self.blocks = blocks
- def __iter__(self, blocks=None):
+ def __iter__(self, blocks: typing.Optional[List[Any]] = None) -> Iterator[str]:
"""Iterates the dumped nginx content."""
blocks = blocks or self.blocks
for b0 in blocks:
@@ -100,7 +113,7 @@ class RawNginxDumper:
semicolon = ""
yield "".join(item) + semicolon
- def __str__(self):
+ def __str__(self) -> str:
"""Return the parsed block as a string."""
return ''.join(self)
@@ -108,10 +121,10 @@ class RawNginxDumper:
spacey = lambda x: (isinstance(x, str) and x.isspace()) or x == ''
-class UnspacedList(list):
+class UnspacedList(List[Any]):
"""Wrap a list [of lists], making any whitespace entries magically invisible"""
- def __init__(self, list_source):
+ def __init__(self, list_source: Iterable[Any]) -> None:
# ensure our argument is not a generator, and duplicate any sublists
self.spaced = copy.deepcopy(list(list_source))
self.dirty = False
@@ -122,14 +135,23 @@ class UnspacedList(list):
for i, entry in reversed(list(enumerate(self))):
if isinstance(entry, list):
sublist = UnspacedList(entry)
- list.__setitem__(self, i, sublist)
+ super().__setitem__(i, sublist)
self.spaced[i] = sublist.spaced
elif spacey(entry):
# don't delete comments
if "#" not in self[:i]:
- list.__delitem__(self, i)
+ super().__delitem__(i)
+
+ @overload
+ def _coerce(self, inbound: None) -> Tuple[None, None]: ...
- def _coerce(self, inbound):
+ @overload
+ def _coerce(self, inbound: str) -> Tuple[str, str]: ...
+
+ @overload
+ def _coerce(self, inbound: List[Any]) -> Tuple["UnspacedList", List[Any]]: ...
+
+ def _coerce(self, inbound: Any) -> Tuple[Any, Any]:
"""
Coerce some inbound object to be appropriately usable in this object
@@ -138,100 +160,114 @@ class UnspacedList(list):
:rtype: tuple
"""
- if not isinstance(inbound, list): # str or None
+ if not isinstance(inbound, list): # str or None
return inbound, inbound
else:
if not hasattr(inbound, "spaced"):
inbound = UnspacedList(inbound)
return inbound, inbound.spaced
- def insert(self, i, x):
+ def insert(self, i: int, x: Any) -> None:
+ """Insert object before index."""
item, spaced_item = self._coerce(x)
slicepos = self._spaced_position(i) if i < len(self) else len(self.spaced)
self.spaced.insert(slicepos, spaced_item)
if not spacey(item):
- list.insert(self, i, item)
+ super().insert(i, item)
self.dirty = True
- def append(self, x):
+ def append(self, x: Any) -> None:
+ """Append object to the end of the list."""
item, spaced_item = self._coerce(x)
self.spaced.append(spaced_item)
if not spacey(item):
- list.append(self, item)
+ super().append(item)
self.dirty = True
- def extend(self, x):
+ def extend(self, x: Any) -> None:
+ """Extend list by appending elements from the iterable."""
item, spaced_item = self._coerce(x)
self.spaced.extend(spaced_item)
- list.extend(self, item)
+ super().extend(item)
self.dirty = True
- def __add__(self, other):
- l = copy.deepcopy(self)
- l.extend(other)
- l.dirty = True
- return l
+ def __add__(self, other: List[Any]) -> "UnspacedList":
+ new_list = copy.deepcopy(self)
+ new_list.extend(other)
+ new_list.dirty = True
+ return new_list
- def pop(self, _i=None):
+ def pop(self, *args: Any, **kwargs: Any) -> None:
+ """Function pop() is not implemented for UnspacedList"""
raise NotImplementedError("UnspacedList.pop() not yet implemented")
- def remove(self, _):
+
+ def remove(self, *args: Any, **kwargs: Any) -> None:
+ """Function remove() is not implemented for UnspacedList"""
raise NotImplementedError("UnspacedList.remove() not yet implemented")
- def reverse(self):
+
+ def reverse(self) -> None:
+ """Function reverse() is not implemented for UnspacedList"""
raise NotImplementedError("UnspacedList.reverse() not yet implemented")
- def sort(self, _cmp=None, _key=None, _Rev=None):
+
+ def sort(self, *_args: Any, **_kwargs: Any) -> None:
+ """Function sort() is not implemented for UnspacedList"""
raise NotImplementedError("UnspacedList.sort() not yet implemented")
- def __setslice__(self, _i, _j, _newslice):
+
+ def __setslice__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Slice operations on UnspacedLists not yet implemented")
- def __setitem__(self, i, value):
+ def __setitem__(self, i: Union["SupportsIndex", slice], value: Any) -> None:
if isinstance(i, slice):
raise NotImplementedError("Slice operations on UnspacedLists not yet implemented")
item, spaced_item = self._coerce(value)
self.spaced.__setitem__(self._spaced_position(i), spaced_item)
if not spacey(item):
- list.__setitem__(self, i, item)
+ super().__setitem__(i, item)
self.dirty = True
- def __delitem__(self, i):
+ def __delitem__(self, i: Union["SupportsIndex", slice]) -> None:
+ if isinstance(i, slice):
+ raise NotImplementedError("Slice operations on UnspacedLists not yet implemented")
self.spaced.__delitem__(self._spaced_position(i))
- list.__delitem__(self, i)
+ super().__delitem__(i)
self.dirty = True
- def __deepcopy__(self, memo):
+ def __deepcopy__(self, memo: Any) -> "UnspacedList":
new_spaced = copy.deepcopy(self.spaced, memo=memo)
- l = UnspacedList(new_spaced)
- l.dirty = self.dirty
- return l
+ new_list = UnspacedList(new_spaced)
+ new_list.dirty = self.dirty
+ return new_list
- def is_dirty(self):
+ def is_dirty(self) -> bool:
"""Recurse through the parse tree to figure out if any sublists are dirty"""
if self.dirty:
return True
return any((isinstance(x, UnspacedList) and x.is_dirty() for x in self))
- def _spaced_position(self, idx):
- "Convert from indexes in the unspaced list to positions in the spaced one"
+ def _spaced_position(self, idx: "SupportsIndex") -> int:
+ """Convert from indexes in the unspaced list to positions in the spaced one"""
+ int_idx = idx.__index__()
pos = spaces = 0
# Normalize indexes like list[-1] etc, and save the result
- if idx < 0:
- idx = len(self) + idx
- if not 0 <= idx < len(self):
+ if int_idx < 0:
+ int_idx = len(self) + int_idx
+ if not 0 <= int_idx < len(self):
raise IndexError("list index out of range")
- idx0 = idx
- # Count the number of spaces in the spaced list before idx in the unspaced one
- while idx != -1:
+ int_idx0 = int_idx
+ # Count the number of spaces in the spaced list before int_idx in the unspaced one
+ while int_idx != -1:
if spacey(self.spaced[pos]):
spaces += 1
else:
- idx -= 1
+ int_idx -= 1
pos += 1
- return idx0 + spaces
+ return int_idx0 + spaces
# Shortcut functions to respect Python's serialization interface
# (like pyyaml, picker or json)
-def loads(source):
+def loads(source: str) -> UnspacedList:
"""Parses from a string.
:param str source: The string to parse
@@ -242,34 +278,34 @@ def loads(source):
return UnspacedList(RawNginxParser(source).as_list())
-def load(_file):
+def load(file_: IO[Any]) -> UnspacedList:
"""Parses from a file.
- :param file _file: The file to parse
+ :param file file_: The file to parse
:returns: The parsed tree
:rtype: list
"""
- return loads(_file.read())
+ return loads(file_.read())
def dumps(blocks: UnspacedList) -> str:
"""Dump to a Unicode string.
- :param UnspacedList block: The parsed tree
+ :param UnspacedList blocks: The parsed tree
:rtype: six.text_type
"""
return str(RawNginxDumper(blocks.spaced))
-def dump(blocks: UnspacedList, _file: IO[Any]) -> None:
+def dump(blocks: UnspacedList, file_: IO[Any]) -> None:
"""Dump to a file.
- :param UnspacedList block: The parsed tree
- :param IO[Any] _file: The file stream to dump to. It must be opened with
+ :param UnspacedList blocks: The parsed tree
+ :param IO[Any] file_: The file stream to dump to. It must be opened with
Unicode encoding.
:rtype: None
"""
- _file.write(dumps(blocks))
+ file_.write(dumps(blocks))
diff --git a/certbot-nginx/certbot_nginx/_internal/obj.py b/certbot-nginx/certbot_nginx/_internal/obj.py
index 44be0e598..1e0dbba1a 100644
--- a/certbot-nginx/certbot_nginx/_internal/obj.py
+++ b/certbot-nginx/certbot_nginx/_internal/obj.py
@@ -1,10 +1,17 @@
"""Module contains classes used by the Nginx Configurator."""
import re
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Union
from certbot.plugins import common
ADD_HEADER_DIRECTIVE = 'add_header'
+
class Addr(common.Addr):
r"""Represents an Nginx address, i.e. what comes after the 'listen'
directive.
@@ -34,7 +41,8 @@ class Addr(common.Addr):
UNSPECIFIED_IPV4_ADDRESSES = ('', '*', '0.0.0.0')
CANONICAL_UNSPECIFIED_ADDRESS = UNSPECIFIED_IPV4_ADDRESSES[0]
- def __init__(self, host, port, ssl, default, ipv6, ipv6only):
+ def __init__(self, host: str, port: str, ssl: bool, default: bool,
+ ipv6: bool, ipv6only: bool) -> None:
super().__init__((host, port))
self.ssl = ssl
self.default = default
@@ -43,7 +51,7 @@ class Addr(common.Addr):
self.unspecified_address = host in self.UNSPECIFIED_IPV4_ADDRESSES
@classmethod
- def fromstring(cls, str_addr):
+ def fromstring(cls, str_addr: str) -> Optional["Addr"]:
"""Initialize Addr from string."""
parts = str_addr.split(' ')
ssl = False
@@ -94,7 +102,7 @@ class Addr(common.Addr):
return cls(host, port, ssl, default, ipv6, ipv6only)
- def to_string(self, include_default=True):
+ def to_string(self, include_default: bool = True) -> str:
"""Return string representation of Addr"""
parts = ''
if self.tup[0] and self.tup[1]:
@@ -111,18 +119,18 @@ class Addr(common.Addr):
return parts
- def __str__(self):
+ def __str__(self) -> str:
return self.to_string()
- def __repr__(self):
+ def __repr__(self) -> str:
return "Addr(" + self.__str__() + ")"
- def __hash__(self): # pylint: disable=useless-super-delegation
+ def __hash__(self) -> int: # pylint: disable=useless-super-delegation
# Python 3 requires explicit overridden for __hash__
# See certbot-apache/certbot_apache/_internal/obj.py for more information
return super().__hash__()
- def super_eq(self, other):
+ def super_eq(self, other: "Addr") -> bool:
"""Check ip/port equality, with IPv6 support.
"""
# If both addresses got an unspecified address, then make sure the
@@ -134,7 +142,7 @@ class Addr(common.Addr):
other.tup[1]), other.ipv6)
return super().__eq__(other)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
return (self.super_eq(other) and
self.ssl == other.ssl and
@@ -158,7 +166,8 @@ class VirtualHost:
"""
- def __init__(self, filep, addrs, ssl, enabled, names, raw, path):
+ def __init__(self, filep: str, addrs: Sequence[Addr], ssl: bool, enabled: bool,
+ names: Set[str], raw: List[Any], path: List[int]) -> None:
"""Initialize a VH."""
self.filep = filep
self.addrs = addrs
@@ -168,7 +177,7 @@ class VirtualHost:
self.raw = raw
self.path = path
- def __str__(self):
+ def __str__(self) -> str:
addr_str = ", ".join(str(addr) for addr in sorted(self.addrs, key=str))
# names might be a set, and it has different representations in Python
# 2 and 3. Force it to be a list here for consistent outputs
@@ -179,10 +188,10 @@ class VirtualHost:
"enabled: %s" % (self.filep, addr_str,
list(self.names), self.ssl, self.enabled))
- def __repr__(self):
+ def __repr__(self) -> str:
return "VirtualHost(" + self.__str__().replace("\n", ", ") + ")\n"
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
return (self.filep == other.filep and
sorted(self.addrs, key=str) == sorted(other.addrs, key=str) and
@@ -193,12 +202,12 @@ class VirtualHost:
return False
- def __hash__(self):
+ def __hash__(self) -> int:
return hash((self.filep, tuple(self.path),
tuple(self.addrs), tuple(self.names),
self.ssl, self.enabled))
- def has_header(self, header_name):
+ def has_header(self, header_name: str) -> bool:
"""Determine if this server block has a particular header set.
:param str header_name: The name of the header to check for, e.g.
'Strict-Transport-Security'
@@ -206,7 +215,7 @@ class VirtualHost:
found = _find_directive(self.raw, ADD_HEADER_DIRECTIVE, header_name)
return found is not None
- def contains_list(self, test):
+ def contains_list(self, test: List[Any]) -> bool:
"""Determine if raw server block contains test list at top level
"""
for i in range(0, len(self.raw) - len(test) + 1):
@@ -214,7 +223,7 @@ class VirtualHost:
return True
return False
- def ipv6_enabled(self):
+ def ipv6_enabled(self) -> bool:
"""Return true if one or more of the listen directives in vhost supports
IPv6"""
for a in self.addrs:
@@ -222,7 +231,7 @@ class VirtualHost:
return True
return False
- def ipv4_enabled(self):
+ def ipv4_enabled(self) -> bool:
"""Return true if one or more of the listen directives in vhost are IPv4
only"""
if not self.addrs:
@@ -232,7 +241,7 @@ class VirtualHost:
return True
return False
- def display_repr(self):
+ def display_repr(self) -> str:
"""Return a representation of VHost to be used in dialog"""
return (
"File: {filename}\n"
@@ -244,7 +253,9 @@ class VirtualHost:
names=", ".join(self.names),
https="Yes" if self.ssl else "No"))
-def _find_directive(directives, directive_name, match_content=None):
+
+def _find_directive(directives: Optional[Union[str, List[Any]]], directive_name: str,
+ match_content: Optional[Any] = None) -> Optional[Any]:
"""Find a directive of type directive_name in directives. If match_content is given,
Searches for `match_content` in the directive arguments.
"""
diff --git a/certbot-nginx/certbot_nginx/_internal/parser.py b/certbot-nginx/certbot_nginx/_internal/parser.py
index 83ec4c923..e60d66eb2 100644
--- a/certbot-nginx/certbot_nginx/_internal/parser.py
+++ b/certbot-nginx/certbot_nginx/_internal/parser.py
@@ -5,20 +5,26 @@ import glob
import io
import logging
import re
+from typing import Any
+from typing import Callable
+from typing import cast
from typing import Dict
+from typing import Iterable
from typing import List
+from typing import Mapping
from typing import Optional
+from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union
+from certbot_nginx._internal import nginxparser
+from certbot_nginx._internal import obj
+from certbot_nginx._internal.nginxparser import UnspacedList
import pyparsing
from certbot import errors
from certbot.compat import os
-from certbot_nginx._internal import nginxparser
-from certbot_nginx._internal import obj
-from certbot_nginx._internal.nginxparser import UnspacedList
logger = logging.getLogger(__name__)
@@ -32,8 +38,8 @@ class NginxParser:
"""
- def __init__(self, root):
- self.parsed: Dict[str, Union[List, nginxparser.UnspacedList]] = {}
+ def __init__(self, root: str) -> None:
+ self.parsed: Dict[str, UnspacedList] = {}
self.root = os.path.abspath(root)
self.config_root = self._find_config_root()
@@ -42,14 +48,14 @@ class NginxParser:
# not enable sites from there.
self.load()
- def load(self):
+ def load(self) -> None:
"""Loads Nginx files into a parsed tree.
"""
self.parsed = {}
self._parse_recursively(self.config_root)
- def _parse_recursively(self, filepath):
+ def _parse_recursively(self, filepath: str) -> None:
"""Parses nginx config files recursively by looking at 'include'
directives inside 'http' and 'server' blocks. Note that this only
reads Nginx files that potentially declare a virtual host.
@@ -77,7 +83,7 @@ class NginxParser:
if _is_include_directive(server_entry):
self._parse_recursively(server_entry[1])
- def abs_path(self, path):
+ def abs_path(self, path: str) -> str:
"""Converts a relative path to an absolute path relative to the root.
Does nothing for paths that are already absolute.
@@ -90,7 +96,7 @@ class NginxParser:
return os.path.normpath(os.path.join(self.root, path))
return os.path.normpath(path)
- def _build_addr_to_ssl(self):
+ def _build_addr_to_ssl(self) -> Dict[Tuple[str, str], bool]:
"""Builds a map from address to whether it listens on ssl in any server block
"""
servers = self._get_raw_servers()
@@ -107,11 +113,11 @@ class NginxParser:
addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple]
return addr_to_ssl
- def _get_raw_servers(self) -> Dict:
+ def _get_raw_servers(self) -> Dict[str, Union[List[Any], UnspacedList]]:
# pylint: disable=cell-var-from-loop
"""Get a map of unparsed all server blocks
"""
- servers: Dict[str, Union[List, nginxparser.UnspacedList]] = {}
+ servers: Dict[str, Union[List[Any], nginxparser.UnspacedList]] = {}
for filename, tree in self.parsed.items():
servers[filename] = []
srv = servers[filename] # workaround undefined loop var in lambdas
@@ -126,7 +132,7 @@ class NginxParser:
servers[filename][i] = (new_server, path)
return servers
- def get_vhosts(self):
+ def get_vhosts(self) -> List[obj.VirtualHost]:
"""Gets list of all 'virtual hosts' found in Nginx configuration.
Technically this is a misnomer because Nginx does not have virtual
hosts, it has 'server blocks'.
@@ -158,7 +164,7 @@ class NginxParser:
return vhosts
- def _update_vhosts_addrs_ssl(self, vhosts):
+ def _update_vhosts_addrs_ssl(self, vhosts: Iterable[obj.VirtualHost]) -> None:
"""Update a list of raw parsed vhosts to include global address sslishness
"""
addr_to_ssl = self._build_addr_to_ssl()
@@ -168,7 +174,7 @@ class NginxParser:
if addr.ssl:
vhost.ssl = True
- def _get_included_directives(self, block):
+ def _get_included_directives(self, block: UnspacedList) -> UnspacedList:
"""Returns array with the "include" directives expanded out by
concatenating the contents of the included file to the block.
@@ -188,7 +194,7 @@ class NginxParser:
pass
return result
- def _parse_files(self, filepath, override=False):
+ def _parse_files(self, filepath: str, override: bool = False) -> List[UnspacedList]:
"""Parse files from a glob
:param str filepath: Nginx config file path
@@ -219,7 +225,7 @@ class NginxParser:
logger.warning("Could not parse file: %s due to %s", item, err)
return trees
- def _find_config_root(self):
+ def _find_config_root(self) -> str:
"""Return the Nginx Configuration Root file."""
location = ['nginx.conf']
@@ -230,7 +236,7 @@ class NginxParser:
raise errors.NoInstallationError(
"Could not find Nginx root configuration file (nginx.conf)")
- def filedump(self, ext='tmp', lazy=True):
+ def filedump(self, ext: str = 'tmp', lazy: bool = True) -> None:
"""Dumps parsed configurations into files.
:param str ext: The file extension to use for the dumped files. If
@@ -255,7 +261,7 @@ class NginxParser:
except IOError:
logger.error("Could not open file for writing: %s", filename)
- def parse_server(self, server):
+ def parse_server(self, server: UnspacedList) -> Dict[str, Any]:
"""Parses a list of server directives, accounting for global address sslishness.
:param list server: list of directives in a server block
@@ -266,7 +272,7 @@ class NginxParser:
_apply_global_addr_ssl(addr_to_ssl, parsed_server)
return parsed_server
- def has_ssl_on_directive(self, vhost):
+ def has_ssl_on_directive(self, vhost: obj.VirtualHost) -> bool:
"""Does vhost have ssl on for all ports?
:param :class:`~certbot_nginx._internal.obj.VirtualHost` vhost: The vhost in question
@@ -284,7 +290,8 @@ class NginxParser:
return False
- def add_server_directives(self, vhost, directives, insert_at_top=False):
+ def add_server_directives(self, vhost: obj.VirtualHost, directives: List[Any],
+ insert_at_top: bool = False) -> None:
"""Add directives to the server block identified by vhost.
This method modifies vhost to be fully consistent with the new directives.
@@ -305,7 +312,8 @@ class NginxParser:
self._modify_server_directives(vhost,
functools.partial(_add_directives, directives, insert_at_top))
- def update_or_add_server_directives(self, vhost, directives, insert_at_top=False):
+ def update_or_add_server_directives(self, vhost: obj.VirtualHost, directives: List[Any],
+ insert_at_top: bool = False) -> None:
"""Add or replace directives in the server block identified by vhost.
This method modifies vhost to be fully consistent with the new directives.
@@ -327,7 +335,8 @@ class NginxParser:
self._modify_server_directives(vhost,
functools.partial(_update_or_add_directives, directives, insert_at_top))
- def remove_server_directives(self, vhost, directive_name, match_func=None):
+ def remove_server_directives(self, vhost: obj.VirtualHost, directive_name: str,
+ match_func: Optional[Callable[[Any], bool]] = None) -> None:
"""Remove all directives of type directive_name.
:param :class:`~certbot_nginx._internal.obj.VirtualHost` vhost: The vhost
@@ -339,7 +348,8 @@ class NginxParser:
self._modify_server_directives(vhost,
functools.partial(_remove_directives, directive_name, match_func))
- def _update_vhost_based_on_new_directives(self, vhost, directives_list):
+ def _update_vhost_based_on_new_directives(self, vhost: obj.VirtualHost,
+ directives_list: UnspacedList) -> None:
new_server = self._get_included_directives(directives_list)
parsed_server = self.parse_server(new_server)
vhost.addrs = parsed_server['addrs']
@@ -347,7 +357,8 @@ class NginxParser:
vhost.names = parsed_server['names']
vhost.raw = new_server
- def _modify_server_directives(self, vhost, block_func):
+ def _modify_server_directives(self, vhost: obj.VirtualHost,
+ block_func: Callable[[List[Any]], None]) -> None:
filename = vhost.filep
try:
result = self.parsed[filename]
@@ -364,7 +375,7 @@ class NginxParser:
def duplicate_vhost(self, vhost_template: obj.VirtualHost,
remove_singleton_listen_params: bool = False,
- only_directives: Optional[List] = None) -> obj.VirtualHost:
+ only_directives: Optional[List[Any]] = None) -> obj.VirtualHost:
"""Duplicate the vhost in the configuration files.
:param :class:`~certbot_nginx._internal.obj.VirtualHost` vhost_template: The vhost
@@ -417,7 +428,7 @@ class NginxParser:
return new_vhost
-def _parse_ssl_options(ssl_options):
+def _parse_ssl_options(ssl_options: Optional[str]) -> List[UnspacedList]:
if ssl_options is not None:
try:
with io.open(ssl_options, "r", encoding="utf-8") as _file:
@@ -429,9 +440,12 @@ def _parse_ssl_options(ssl_options):
"Only UTF-8 encoding is supported.", ssl_options)
except pyparsing.ParseBaseException as err:
logger.warning("Could not parse file: %s due to %s", ssl_options, err)
- return []
+ return UnspacedList([])
+
-def _do_for_subarray(entry, condition, func, path=None):
+def _do_for_subarray(entry: List[Any], condition: Callable[[List[Any]], bool],
+ func: Callable[[List[Any], List[int]], None],
+ path: Optional[List[int]] = None) -> None:
"""Executes a function for a subarray of a nested array if it matches
the given condition.
@@ -450,7 +464,7 @@ def _do_for_subarray(entry, condition, func, path=None):
_do_for_subarray(item, condition, func, path + [index])
-def get_best_match(target_name, names):
+def get_best_match(target_name: str, names: Iterable[str]) -> Tuple[Optional[str], Optional[str]]:
"""Finds the best match for target_name out of names using the Nginx
name-matching rules (exact > longest wildcard starting with * >
longest wildcard ending with * > regex).
@@ -479,29 +493,29 @@ def get_best_match(target_name, names):
if exact:
# There can be more than one exact match; e.g. eff.org, .eff.org
match = min(exact, key=len)
- return ('exact', match)
+ return 'exact', match
if wildcard_start:
# Return the longest wildcard
match = max(wildcard_start, key=len)
- return ('wildcard_start', match)
+ return 'wildcard_start', match
if wildcard_end:
# Return the longest wildcard
match = max(wildcard_end, key=len)
- return ('wildcard_end', match)
+ return 'wildcard_end', match
if regex:
# Just return the first one for now
match = regex[0]
- return ('regex', match)
+ return 'regex', match
- return (None, None)
+ return None, None
-def _exact_match(target_name, name):
+def _exact_match(target_name: str, name: str) -> bool:
target_lower = target_name.lower()
return name.lower() in (target_lower, '.' + target_lower)
-def _wildcard_match(target_name, name, start):
+def _wildcard_match(target_name: str, name: str, start: bool) -> bool:
# Degenerate case
if name == '*':
return True
@@ -526,7 +540,7 @@ def _wildcard_match(target_name, name, start):
return target_name_lower.endswith('.' + name_lower)
-def _regex_match(target_name, name):
+def _regex_match(target_name: str, name: str) -> bool:
# Must start with a tilde
if len(name) < 2 or name[0] != '~':
return False
@@ -534,13 +548,13 @@ def _regex_match(target_name, name):
# After tilde is a perl-compatible regex
try:
regex = re.compile(name[1:])
- return re.match(regex, target_name)
+ return bool(re.match(regex, target_name))
except re.error: # pragma: no cover
# perl-compatible regexes are sometimes not recognized by python
return False
-def _is_include_directive(entry):
+def _is_include_directive(entry: Any) -> bool:
"""Checks if an nginx parsed entry is an 'include' directive.
:param list entry: the parsed entry
@@ -552,7 +566,8 @@ def _is_include_directive(entry):
len(entry) == 2 and entry[0] == 'include' and
isinstance(entry[1], str))
-def _is_ssl_on_directive(entry):
+
+def _is_ssl_on_directive(entry: Any) -> bool:
"""Checks if an nginx parsed entry is an 'ssl on' directive.
:param list entry: the parsed entry
@@ -564,14 +579,18 @@ def _is_ssl_on_directive(entry):
len(entry) == 2 and entry[0] == 'ssl' and
entry[1] == 'on')
-def _add_directives(directives, insert_at_top, block):
+
+def _add_directives(directives: List[Any], insert_at_top: bool,
+ block: UnspacedList) -> None:
"""Adds directives to a config block."""
for directive in directives:
_add_directive(block, directive, insert_at_top)
if block and '\n' not in block[-1]: # could be " \n " or ["\n"] !
block.append(nginxparser.UnspacedList('\n'))
-def _update_or_add_directives(directives, insert_at_top, block):
+
+def _update_or_add_directives(directives: List[Any], insert_at_top: bool,
+ block: UnspacedList) -> None:
"""Adds or replaces directives in a config block."""
for directive in directives:
_update_or_add_directive(block, directive, insert_at_top)
@@ -584,7 +603,8 @@ REPEATABLE_DIRECTIVES = {'server_name', 'listen', INCLUDE, 'rewrite', 'add_heade
COMMENT = ' managed by Certbot'
COMMENT_BLOCK = [' ', '#', COMMENT]
-def comment_directive(block, location):
+
+def comment_directive(block: UnspacedList, location: int) -> None:
"""Add a ``#managed by Certbot`` comment to the end of the line at location.
:param list block: The block containing the directive to be commented
@@ -603,40 +623,45 @@ def comment_directive(block, location):
if next_entry is not None and "\n" not in next_entry:
block.insert(location + 2, '\n')
-def _comment_out_directive(block, location, include_location):
+
+def _comment_out_directive(block: UnspacedList, location: int, include_location: str) -> None:
"""Comment out the line at location, with a note of explanation."""
comment_message = ' duplicated in {0}'.format(include_location)
# add the end comment
# create a dumpable object out of block[location] (so it includes the ;)
directive = block[location]
- new_dir_block = nginxparser.UnspacedList([]) # just a wrapper
+ new_dir_block = nginxparser.UnspacedList([]) # just a wrapper
new_dir_block.append(directive)
dumped = nginxparser.dumps(new_dir_block)
- commented = dumped + ' #' + comment_message # add the comment directly to the one-line string
- new_dir = nginxparser.loads(commented) # reload into UnspacedList
+ commented = dumped + ' #' + comment_message # add the comment directly to the one-line string
+ new_dir = nginxparser.loads(commented) # reload into UnspacedList
# add the beginning comment
insert_location = 0
- if new_dir[0].spaced[0] != new_dir[0][0]: # if there's whitespace at the beginning
+ if new_dir[0].spaced[0] != new_dir[0][0]: # if there's whitespace at the beginning
insert_location = 1
- new_dir[0].spaced.insert(insert_location, "# ") # comment out the line
- new_dir[0].spaced.append(";") # directly add in the ;, because now dumping won't work properly
+ new_dir[0].spaced.insert(insert_location, "# ") # comment out the line
+ new_dir[0].spaced.append(";") # directly add in the ;, because now dumping won't work properly
dumped = nginxparser.dumps(new_dir)
- new_dir = nginxparser.loads(dumped) # reload into an UnspacedList
+ new_dir = nginxparser.loads(dumped) # reload into an UnspacedList
block[location] = new_dir[0] # set the now-single-line-comment directive back in place
-def _find_location(block, directive_name, match_func=None):
+
+def _find_location(block: UnspacedList, directive_name: str,
+ match_func: Optional[Callable[[Any], bool]] = None) -> Optional[int]:
"""Finds the index of the first instance of directive_name in block.
If no line exists, use None."""
- return next((index for index, line in enumerate(block) \
- if line and line[0] == directive_name and (match_func is None or match_func(line))), None)
+ return next((index for index, line in enumerate(block) if (
+ line and line[0] == directive_name and (match_func is None or match_func(line)))), None)
+
-def _is_whitespace_or_comment(directive):
+def _is_whitespace_or_comment(directive: Sequence[Any]) -> bool:
"""Is this directive either a whitespace or comment directive?"""
return len(directive) == 0 or directive[0] == '#'
-def _add_directive(block, directive, insert_at_top):
+
+def _add_directive(block: UnspacedList, directive: Sequence[Any], insert_at_top: bool) -> None:
if not isinstance(directive, nginxparser.UnspacedList):
directive = nginxparser.UnspacedList(directive)
if _is_whitespace_or_comment(directive):
@@ -653,10 +678,11 @@ def _add_directive(block, directive, insert_at_top):
# handle flat include files
directive_name = directive[0]
- def can_append(loc, dir_name):
+
+ def can_append(loc: Optional[int], dir_name: str) -> bool:
""" Can we append this directive to the block? """
return loc is None or (isinstance(dir_name, str)
- and dir_name in REPEATABLE_DIRECTIVES)
+ and dir_name in REPEATABLE_DIRECTIVES)
err_fmt = 'tried to insert directive "{0}" but found conflicting "{1}".'
@@ -672,10 +698,14 @@ def _add_directive(block, directive, insert_at_top):
included_dir_name = included_directive[0]
if (not _is_whitespace_or_comment(included_directive)
and not can_append(included_dir_loc, included_dir_name)):
- if block[included_dir_loc] != included_directive:
- raise errors.MisconfigurationError(err_fmt.format(included_directive,
- block[included_dir_loc]))
- _comment_out_directive(block, included_dir_loc, directive[1])
+
+ # By construction of can_append(), included_dir_loc cannot be None at that point
+ resolved_included_dir_loc = cast(int, included_dir_loc)
+
+ if block[resolved_included_dir_loc] != included_directive:
+ raise errors.MisconfigurationError(err_fmt.format(
+ included_directive, block[resolved_included_dir_loc]))
+ _comment_out_directive(block, resolved_included_dir_loc, directive[1])
if can_append(location, directive_name):
if insert_at_top:
@@ -687,14 +717,22 @@ def _add_directive(block, directive, insert_at_top):
else:
block.append(directive)
comment_directive(block, len(block) - 1)
- elif block[location] != directive:
- raise errors.MisconfigurationError(err_fmt.format(directive, block[location]))
+ return
+
+ # By construction of can_append(), location cannot be None at that point
+ resolved_location = cast(int, location)
+
+ if block[resolved_location] != directive:
+ raise errors.MisconfigurationError(err_fmt.format(directive, block[resolved_location]))
-def _update_directive(block, directive, location):
+
+def _update_directive(block: UnspacedList, directive: Sequence[Any], location: int) -> None:
block[location] = directive
comment_directive(block, location)
-def _update_or_add_directive(block, directive, insert_at_top):
+
+def _update_or_add_directive(block: UnspacedList, directive: Sequence[Any],
+ insert_at_top: bool) -> None:
if not isinstance(directive, nginxparser.UnspacedList):
directive = nginxparser.UnspacedList(directive)
if _is_whitespace_or_comment(directive):
@@ -711,10 +749,13 @@ def _update_or_add_directive(block, directive, insert_at_top):
_add_directive(block, directive, insert_at_top)
-def _is_certbot_comment(directive):
+
+def _is_certbot_comment(directive: Sequence[Any]) -> bool:
return '#' in directive and COMMENT in directive
-def _remove_directives(directive_name, match_func, block):
+
+def _remove_directives(directive_name: str, match_func: Callable[[Any], bool],
+ block: UnspacedList) -> None:
"""Removes directives of name directive_name from a config block if match_func matches.
"""
while True:
@@ -726,7 +767,9 @@ def _remove_directives(directive_name, match_func, block):
del block[location + 1]
del block[location]
-def _apply_global_addr_ssl(addr_to_ssl, parsed_server):
+
+def _apply_global_addr_ssl(addr_to_ssl: Mapping[Tuple[str, str], bool],
+ parsed_server: Dict[str, Any]) -> None:
"""Apply global sslishness information to the parsed server block
"""
for addr in parsed_server['addrs']:
@@ -734,7 +777,8 @@ def _apply_global_addr_ssl(addr_to_ssl, parsed_server):
if addr.ssl:
parsed_server['ssl'] = True
-def _parse_server_raw(server):
+
+def _parse_server_raw(server: UnspacedList) -> Dict[str, Any]:
"""Parses a list of server directives.
:param list server: list of directives in a server block
diff --git a/certbot-nginx/certbot_nginx/_internal/parser_obj.py b/certbot-nginx/certbot_nginx/_internal/parser_obj.py
index 33ed822c3..0af38a936 100644
--- a/certbot-nginx/certbot_nginx/_internal/parser_obj.py
+++ b/certbot-nginx/certbot_nginx/_internal/parser_obj.py
@@ -5,7 +5,14 @@ raw lists of tokens from pyparsing. """
import abc
import logging
+from typing import Any
+from typing import Callable
+from typing import Iterator
from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
from certbot import errors
@@ -23,24 +30,24 @@ class Parsable:
__metaclass__ = abc.ABCMeta
- def __init__(self, parent=None):
- self._data: List[object] = []
+ def __init__(self, parent: Optional["Parsable"] = None):
+ self._data: List[Any] = []
self._tabs = None
self.parent = parent
@classmethod
- def parsing_hooks(cls):
+ def parsing_hooks(cls) -> Tuple[Type["Block"], Type["Sentence"], Type["Statements"]]:
"""Returns object types that this class should be able to `parse` recusrively.
The order of the objects indicates the order in which the parser should
try to parse each subitem.
:returns: A list of Parsable classes.
:rtype list:
"""
- return (Block, Sentence, Statements)
+ return Block, Sentence, Statements
@staticmethod
@abc.abstractmethod
- def should_parse(lists):
+ def should_parse(lists: Any) -> bool:
""" Returns whether the contents of `lists` can be parsed into this object.
:returns: Whether `lists` can be parsed as this object.
@@ -49,7 +56,7 @@ class Parsable:
raise NotImplementedError()
@abc.abstractmethod
- def parse(self, raw_list, add_spaces=False):
+ def parse(self, raw_list: List[Any], add_spaces: bool = False) -> None:
""" Loads information into this object from underlying raw_list structure.
Each Parsable object might make different assumptions about the structure of
raw_list.
@@ -64,7 +71,8 @@ class Parsable:
raise NotImplementedError()
@abc.abstractmethod
- def iterate(self, expanded=False, match=None):
+ def iterate(self, expanded: bool = False,
+ match: Optional[Callable[["Parsable"], bool]] = None) -> Iterator[Any]:
""" Iterates across this object. If this object is a leaf object, only yields
itself. If it contains references other parsing objects, and `expanded` is set,
this function should first yield itself, then recursively iterate across all of them.
@@ -77,7 +85,7 @@ class Parsable:
raise NotImplementedError()
@abc.abstractmethod
- def get_tabs(self):
+ def get_tabs(self) -> str:
""" Guess at the tabbing style of this parsed object, based on whitespace.
If this object is a leaf, it deducts the tabbing based on its own contents.
@@ -90,7 +98,7 @@ class Parsable:
raise NotImplementedError()
@abc.abstractmethod
- def set_tabs(self, tabs=" "):
+ def set_tabs(self, tabs: str = " ") -> None:
"""This tries to set and alter the tabbing of the current object to a desired
whitespace string. Primarily meant for objects that were constructed, so they
can conform to surrounding whitespace.
@@ -99,7 +107,7 @@ class Parsable:
"""
raise NotImplementedError()
- def dump(self, include_spaces=False):
+ def dump(self, include_spaces: bool = False) -> List[Any]:
""" Dumps back to pyparsing-like list tree. The opposite of `parse`.
Note: if this object has not been modified, `dump` with `include_spaces=True`
@@ -121,17 +129,17 @@ class Statements(Parsable):
an extra `_trailing_whitespace` string to keep track of the whitespace that does not
precede any more statements.
"""
- def __init__(self, parent=None):
+ def __init__(self, parent: Optional[Parsable] = None):
super().__init__(parent)
self._trailing_whitespace = None
# ======== Begin overridden functions
@staticmethod
- def should_parse(lists):
+ def should_parse(lists: Any) -> bool:
return isinstance(lists, list)
- def set_tabs(self, tabs=" "):
+ def set_tabs(self, tabs: str = " ") -> None:
""" Sets the tabbing for this set of statements. Does this by calling `set_tabs`
on each of the child statements.
@@ -144,7 +152,7 @@ class Statements(Parsable):
if self.parent is not None:
self._trailing_whitespace = "\n" + self.parent.get_tabs()
- def parse(self, raw_list, add_spaces=False):
+ def parse(self, raw_list: List[Any], add_spaces: bool = False) -> None:
""" Parses a list of statements.
Expects all elements in `raw_list` to be parseable by `type(self).parsing_hooks`,
with an optional whitespace string at the last index of `raw_list`.
@@ -157,14 +165,14 @@ class Statements(Parsable):
raw_list = raw_list[:-1]
self._data = [parse_raw(elem, self, add_spaces) for elem in raw_list]
- def get_tabs(self):
+ def get_tabs(self) -> str:
""" Takes a guess at the tabbing of all contained Statements by retrieving the
tabbing of the first Statement."""
if self._data:
return self._data[0].get_tabs()
return ""
- def dump(self, include_spaces=False):
+ def dump(self, include_spaces: bool = False) -> List[Any]:
""" Dumps this object by first dumping each statement, then appending its
trailing whitespace (if `include_spaces` is set) """
data = super().dump(include_spaces)
@@ -172,7 +180,8 @@ class Statements(Parsable):
return data + [self._trailing_whitespace]
return data
- def iterate(self, expanded=False, match=None):
+ def iterate(self, expanded: bool = False,
+ match: Optional[Callable[["Parsable"], bool]] = None) -> Iterator[Any]:
""" Combines each statement's iterator. """
for elem in self._data:
for sub_elem in elem.iterate(expanded, match):
@@ -181,7 +190,7 @@ class Statements(Parsable):
# ======== End overridden functions
-def _space_list(list_):
+def _space_list(list_: Sequence[Any]) -> List[str]:
""" Inserts whitespace between adjacent non-whitespace tokens. """
spaced_statement: List[str] = []
for i in reversed(range(len(list_))):
@@ -197,7 +206,7 @@ class Sentence(Parsable):
# ======== Begin overridden functions
@staticmethod
- def should_parse(lists):
+ def should_parse(lists: Any) -> bool:
""" Returns True if `lists` can be parseable as a `Sentence`-- that is,
every element is a string type.
@@ -205,38 +214,39 @@ class Sentence(Parsable):
:returns: whether this lists is parseable by `Sentence`.
"""
- return isinstance(lists, list) and len(lists) > 0 and \
- all(isinstance(elem, str) for elem in lists)
+ return (isinstance(lists, list) and len(lists) > 0 and
+ all(isinstance(elem, str) for elem in lists))
- def parse(self, raw_list, add_spaces=False):
+ def parse(self, raw_list: List[Any], add_spaces: bool = False) -> None:
""" Parses a list of string types into this object.
If add_spaces is set, adds whitespace tokens between adjacent non-whitespace tokens."""
if add_spaces:
raw_list = _space_list(raw_list)
- if not isinstance(raw_list, list) or \
- any(not isinstance(elem, str) for elem in raw_list):
+ if (not isinstance(raw_list, list)
+ or any(not isinstance(elem, str) for elem in raw_list)):
raise errors.MisconfigurationError("Sentence parsing expects a list of string types.")
self._data = raw_list
- def iterate(self, expanded=False, match=None):
+ def iterate(self, expanded: bool = False,
+ match: Optional[Callable[[Parsable], bool]] = None) -> Iterator[Any]:
""" Simply yields itself. """
if match is None or match(self):
yield self
- def set_tabs(self, tabs=" "):
+ def set_tabs(self, tabs: str = " ") -> None:
""" Sets the tabbing on this sentence. Inserts a newline and `tabs` at the
beginning of `self._data`. """
if self._data[0].isspace():
return
self._data.insert(0, "\n" + tabs)
- def dump(self, include_spaces=False):
+ def dump(self, include_spaces: bool = False) -> List[Any]:
""" Dumps this sentence. If include_spaces is set, includes whitespace tokens."""
if not include_spaces:
return self.words
return self._data
- def get_tabs(self):
+ def get_tabs(self) -> str:
""" Guesses at the tabbing of this sentence. If the first element is whitespace,
returns the whitespace after the rightmost newline in the string. """
first = self._data[0]
@@ -248,14 +258,14 @@ class Sentence(Parsable):
# ======== End overridden functions
@property
- def words(self):
+ def words(self) -> List[str]:
""" Iterates over words, but without spaces. Like Unspaced List. """
return [word.strip("\"\'") for word in self._data if not word.isspace()]
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> str:
return self.words[index]
- def __contains__(self, word):
+ def __contains__(self, word: str) -> bool:
return word in self.words
@@ -270,13 +280,13 @@ class Block(Parsable):
names = ["block", " ", "name", " "]
contents = [["\n ", "content", " ", "1"], ["\n ", "content", " ", "2"], "\n"]
"""
- def __init__(self, parent=None):
+ def __init__(self, parent: Optional[Parsable] = None) -> None:
super().__init__(parent)
- self.names: Sentence = None
- self.contents: Block = None
+ self.names: Optional[Sentence] = None
+ self.contents: Optional[Block] = None
@staticmethod
- def should_parse(lists):
+ def should_parse(lists: Any) -> bool:
""" Returns True if `lists` can be parseable as a `Block`-- that is,
it's got a length of 2, the first element is a `Sentence` and the second can be
a `Statements`.
@@ -287,13 +297,14 @@ class Block(Parsable):
return isinstance(lists, list) and len(lists) == 2 and \
Sentence.should_parse(lists[0]) and isinstance(lists[1], list)
- def set_tabs(self, tabs=" "):
+ def set_tabs(self, tabs: str = " ") -> None:
""" Sets tabs by setting equivalent tabbing on names, then adding tabbing
to contents."""
self.names.set_tabs(tabs)
self.contents.set_tabs(tabs + " ")
- def iterate(self, expanded=False, match=None):
+ def iterate(self, expanded: bool = False,
+ match: Optional[Callable[[Parsable], bool]] = None) -> Iterator[Any]:
""" Iterator over self, and if expanded is set, over its contents. """
if match is None or match(self):
yield self
@@ -301,7 +312,7 @@ class Block(Parsable):
for elem in self.contents.iterate(expanded, match):
yield elem
- def parse(self, raw_list, add_spaces=False):
+ def parse(self, raw_list: List[Any], add_spaces: bool = False) -> None:
""" Parses a list that resembles a block.
The assumptions that this routine makes are:
@@ -323,11 +334,12 @@ class Block(Parsable):
self.contents.parse(raw_list[1], add_spaces)
self._data = [self.names, self.contents]
- def get_tabs(self):
+ def get_tabs(self) -> str:
""" Guesses tabbing by retrieving tabbing guess of self.names. """
return self.names.get_tabs()
-def _is_comment(parsed_obj):
+
+def _is_comment(parsed_obj: Parsable) -> bool:
""" Checks whether parsed_obj is a comment.
:param .Parsable parsed_obj:
@@ -339,7 +351,8 @@ def _is_comment(parsed_obj):
return False
return parsed_obj.words[0] == "#"
-def _is_certbot_comment(parsed_obj):
+
+def _is_certbot_comment(parsed_obj: Parsable) -> bool:
""" Checks whether parsed_obj is a "managed by Certbot" comment.
:param .Parsable parsed_obj:
@@ -356,7 +369,8 @@ def _is_certbot_comment(parsed_obj):
return False
return True
-def _certbot_comment(parent, preceding_spaces=4):
+
+def _certbot_comment(parent: Parsable, preceding_spaces: int = 4) -> Sentence:
""" A "Managed by Certbot" comment.
:param int preceding_spaces: Number of spaces between the end of the previous
statement and the comment.
@@ -367,7 +381,8 @@ def _certbot_comment(parent, preceding_spaces=4):
result.parse([" " * preceding_spaces] + COMMENT_BLOCK)
return result
-def _choose_parser(parent, list_):
+
+def _choose_parser(parent: Parsable, list_: Any) -> Parsable:
""" Choose a parser from type(parent).parsing_hooks, depending on whichever hook
returns True first. """
hooks = Parsable.parsing_hooks()
@@ -379,7 +394,8 @@ def _choose_parser(parent, list_):
raise errors.MisconfigurationError(
"None of the parsing hooks succeeded, so we don't know how to parse this set of lists.")
-def parse_raw(lists_, parent=None, add_spaces=False):
+
+def parse_raw(lists_: Any, parent: Optional[Parsable] = None, add_spaces: bool = False) -> Parsable:
""" Primary parsing factory function.
:param list lists_: raw lists from pyparsing to parse.