diff options
Diffstat (limited to 'acme/tests/client_test.py')
-rw-r--r-- | acme/tests/client_test.py | 1308 |
1 files changed, 1308 insertions, 0 deletions
diff --git a/acme/tests/client_test.py b/acme/tests/client_test.py new file mode 100644 index 000000000..22eb3fc45 --- /dev/null +++ b/acme/tests/client_test.py @@ -0,0 +1,1308 @@ +"""Tests for acme.client.""" +# pylint: disable=too-many-lines +import copy +import datetime +import json +import unittest + +from six.moves import http_client # pylint: disable=import-error + +import josepy as jose +import mock +import OpenSSL +import requests + +from acme import challenges +from acme import errors +from acme import jws as acme_jws +from acme import messages +from acme.magic_typing import Dict # pylint: disable=unused-import, no-name-in-module + +import messages_test +import test_util + +CERT_DER = test_util.load_vector('cert.der') +CERT_SAN_PEM = test_util.load_vector('cert-san.pem') +CSR_SAN_PEM = test_util.load_vector('csr-san.pem') +KEY = jose.JWKRSA.load(test_util.load_vector('rsa512_key.pem')) +KEY2 = jose.JWKRSA.load(test_util.load_vector('rsa256_key.pem')) + +DIRECTORY_V1 = messages.Directory({ + messages.NewRegistration: + 'https://www.letsencrypt-demo.org/acme/new-reg', + messages.Revocation: + 'https://www.letsencrypt-demo.org/acme/revoke-cert', + messages.NewAuthorization: + 'https://www.letsencrypt-demo.org/acme/new-authz', + messages.CertificateRequest: + 'https://www.letsencrypt-demo.org/acme/new-cert', +}) + +DIRECTORY_V2 = messages.Directory({ + 'newAccount': 'https://www.letsencrypt-demo.org/acme/new-account', + 'newNonce': 'https://www.letsencrypt-demo.org/acme/new-nonce', + 'newOrder': 'https://www.letsencrypt-demo.org/acme/new-order', + 'revokeCert': 'https://www.letsencrypt-demo.org/acme/revoke-cert', +}) + + +class ClientTestBase(unittest.TestCase): + """Base for tests in acme.client.""" + + def setUp(self): + self.response = mock.MagicMock( + ok=True, status_code=http_client.OK, headers={}, links={}) + self.net = mock.MagicMock() + self.net.post.return_value = self.response + self.net.get.return_value = self.response + + self.identifier = messages.Identifier( + typ=messages.IDENTIFIER_FQDN, value='example.com') + + # Registration + self.contact = ('mailto:cert-admin@example.com', 'tel:+12025551212') + reg = messages.Registration( + contact=self.contact, key=KEY.public_key()) + the_arg = dict(reg) # type: Dict + self.new_reg = messages.NewRegistration(**the_arg) + self.regr = messages.RegistrationResource( + body=reg, uri='https://www.letsencrypt-demo.org/acme/reg/1') + + # Authorization + authzr_uri = 'https://www.letsencrypt-demo.org/acme/authz/1' + challb = messages.ChallengeBody( + uri=(authzr_uri + '/1'), status=messages.STATUS_VALID, + chall=challenges.DNS(token=jose.b64decode( + 'evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA'))) + self.challr = messages.ChallengeResource( + body=challb, authzr_uri=authzr_uri) + self.authz = messages.Authorization( + identifier=messages.Identifier( + typ=messages.IDENTIFIER_FQDN, value='example.com'), + challenges=(challb,), combinations=None) + self.authzr = messages.AuthorizationResource( + body=self.authz, uri=authzr_uri) + + # Reason code for revocation + self.rsn = 1 + + +class BackwardsCompatibleClientV2Test(ClientTestBase): + """Tests for acme.client.BackwardsCompatibleClientV2.""" + + def setUp(self): + super(BackwardsCompatibleClientV2Test, self).setUp() + # contains a loaded cert + self.certr = messages.CertificateResource( + body=messages_test.CERT) + + loaded = OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, CERT_SAN_PEM) + wrapped = jose.ComparableX509(loaded) + self.chain = [wrapped, wrapped] + + self.cert_pem = OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, messages_test.CERT.wrapped).decode() + + single_chain = OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, loaded).decode() + self.chain_pem = single_chain + single_chain + + self.fullchain_pem = self.cert_pem + self.chain_pem + + self.orderr = messages.OrderResource( + csr_pem=CSR_SAN_PEM) + + def _init(self): + uri = 'http://www.letsencrypt-demo.org/directory' + from acme.client import BackwardsCompatibleClientV2 + return BackwardsCompatibleClientV2(net=self.net, + key=KEY, server=uri) + + def test_init_downloads_directory(self): + uri = 'http://www.letsencrypt-demo.org/directory' + from acme.client import BackwardsCompatibleClientV2 + BackwardsCompatibleClientV2(net=self.net, + key=KEY, server=uri) + self.net.get.assert_called_once_with(uri) + + def test_init_acme_version(self): + self.response.json.return_value = DIRECTORY_V1.to_json() + client = self._init() + self.assertEqual(client.acme_version, 1) + + self.response.json.return_value = DIRECTORY_V2.to_json() + client = self._init() + self.assertEqual(client.acme_version, 2) + + def test_query_registration_client_v2(self): + self.response.json.return_value = DIRECTORY_V2.to_json() + client = self._init() + self.response.json.return_value = self.regr.body.to_json() + self.assertEqual(self.regr, client.query_registration(self.regr)) + + def test_forwarding(self): + self.response.json.return_value = DIRECTORY_V1.to_json() + client = self._init() + self.assertEqual(client.directory, client.client.directory) + self.assertEqual(client.key, KEY) + self.assertEqual(client.deactivate_registration, client.client.deactivate_registration) + self.assertRaises(AttributeError, client.__getattr__, 'nonexistent') + self.assertRaises(AttributeError, client.__getattr__, 'new_account_and_tos') + self.assertRaises(AttributeError, client.__getattr__, 'new_account') + + def test_new_account_and_tos(self): + # v2 no tos + self.response.json.return_value = DIRECTORY_V2.to_json() + with mock.patch('acme.client.ClientV2') as mock_client: + client = self._init() + client.new_account_and_tos(self.new_reg) + mock_client().new_account.assert_called_with(self.new_reg) + + # v2 tos good + with mock.patch('acme.client.ClientV2') as mock_client: + mock_client().directory.meta.__contains__.return_value = True + client = self._init() + client.new_account_and_tos(self.new_reg, lambda x: True) + mock_client().new_account.assert_called_with( + self.new_reg.update(terms_of_service_agreed=True)) + + # v2 tos bad + with mock.patch('acme.client.ClientV2') as mock_client: + mock_client().directory.meta.__contains__.return_value = True + client = self._init() + def _tos_cb(tos): + raise errors.Error + self.assertRaises(errors.Error, client.new_account_and_tos, + self.new_reg, _tos_cb) + mock_client().new_account.assert_not_called() + + # v1 yes tos + self.response.json.return_value = DIRECTORY_V1.to_json() + with mock.patch('acme.client.Client') as mock_client: + regr = mock.MagicMock(terms_of_service="TOS") + mock_client().register.return_value = regr + client = self._init() + client.new_account_and_tos(self.new_reg) + mock_client().register.assert_called_once_with(self.new_reg) + mock_client().agree_to_tos.assert_called_once_with(regr) + + # v1 no tos + with mock.patch('acme.client.Client') as mock_client: + regr = mock.MagicMock(terms_of_service=None) + mock_client().register.return_value = regr + client = self._init() + client.new_account_and_tos(self.new_reg) + mock_client().register.assert_called_once_with(self.new_reg) + mock_client().agree_to_tos.assert_not_called() + + @mock.patch('OpenSSL.crypto.load_certificate_request') + @mock.patch('acme.crypto_util._pyopenssl_cert_or_req_all_names') + def test_new_order_v1(self, mock__pyopenssl_cert_or_req_all_names, + unused_mock_load_certificate_request): + self.response.json.return_value = DIRECTORY_V1.to_json() + mock__pyopenssl_cert_or_req_all_names.return_value = ['example.com', 'www.example.com'] + mock_csr_pem = mock.MagicMock() + with mock.patch('acme.client.Client') as mock_client: + mock_client().request_domain_challenges.return_value = mock.sentinel.auth + client = self._init() + orderr = client.new_order(mock_csr_pem) + self.assertEqual(orderr.authorizations, [mock.sentinel.auth, mock.sentinel.auth]) + + def test_new_order_v2(self): + self.response.json.return_value = DIRECTORY_V2.to_json() + mock_csr_pem = mock.MagicMock() + with mock.patch('acme.client.ClientV2') as mock_client: + client = self._init() + client.new_order(mock_csr_pem) + mock_client().new_order.assert_called_once_with(mock_csr_pem) + + @mock.patch('acme.client.Client') + def test_finalize_order_v1_success(self, mock_client): + self.response.json.return_value = DIRECTORY_V1.to_json() + + mock_client().request_issuance.return_value = self.certr + mock_client().fetch_chain.return_value = self.chain + + deadline = datetime.datetime(9999, 9, 9) + client = self._init() + result = client.finalize_order(self.orderr, deadline) + self.assertEqual(result.fullchain_pem, self.fullchain_pem) + mock_client().fetch_chain.assert_called_once_with(self.certr) + + @mock.patch('acme.client.Client') + def test_finalize_order_v1_fetch_chain_error(self, mock_client): + self.response.json.return_value = DIRECTORY_V1.to_json() + + mock_client().request_issuance.return_value = self.certr + mock_client().fetch_chain.return_value = self.chain + mock_client().fetch_chain.side_effect = [errors.Error, self.chain] + + deadline = datetime.datetime(9999, 9, 9) + client = self._init() + result = client.finalize_order(self.orderr, deadline) + self.assertEqual(result.fullchain_pem, self.fullchain_pem) + self.assertEqual(mock_client().fetch_chain.call_count, 2) + + @mock.patch('acme.client.Client') + def test_finalize_order_v1_timeout(self, mock_client): + self.response.json.return_value = DIRECTORY_V1.to_json() + + mock_client().request_issuance.return_value = self.certr + + deadline = deadline = datetime.datetime.now() - datetime.timedelta(seconds=60) + client = self._init() + self.assertRaises(errors.TimeoutError, client.finalize_order, + self.orderr, deadline) + + def test_finalize_order_v2(self): + self.response.json.return_value = DIRECTORY_V2.to_json() + mock_orderr = mock.MagicMock() + mock_deadline = mock.MagicMock() + with mock.patch('acme.client.ClientV2') as mock_client: + client = self._init() + client.finalize_order(mock_orderr, mock_deadline) + mock_client().finalize_order.assert_called_once_with(mock_orderr, mock_deadline) + + def test_revoke(self): + self.response.json.return_value = DIRECTORY_V1.to_json() + with mock.patch('acme.client.Client') as mock_client: + client = self._init() + client.revoke(messages_test.CERT, self.rsn) + mock_client().revoke.assert_called_once_with(messages_test.CERT, self.rsn) + + self.response.json.return_value = DIRECTORY_V2.to_json() + with mock.patch('acme.client.ClientV2') as mock_client: + client = self._init() + client.revoke(messages_test.CERT, self.rsn) + mock_client().revoke.assert_called_once_with(messages_test.CERT, self.rsn) + + def test_update_registration(self): + self.response.json.return_value = DIRECTORY_V1.to_json() + with mock.patch('acme.client.Client') as mock_client: + client = self._init() + client.update_registration(mock.sentinel.regr, None) + mock_client().update_registration.assert_called_once_with(mock.sentinel.regr, None) + + # newNonce present means it will pick acme_version 2 + def test_external_account_required_true(self): + self.response.json.return_value = messages.Directory({ + 'newNonce': 'http://letsencrypt-test.com/acme/new-nonce', + 'meta': messages.Directory.Meta(external_account_required=True), + }).to_json() + + client = self._init() + + self.assertTrue(client.external_account_required()) + + # newNonce present means it will pick acme_version 2 + def test_external_account_required_false(self): + self.response.json.return_value = messages.Directory({ + 'newNonce': 'http://letsencrypt-test.com/acme/new-nonce', + 'meta': messages.Directory.Meta(external_account_required=False), + }).to_json() + + client = self._init() + + self.assertFalse(client.external_account_required()) + + def test_external_account_required_false_v1(self): + self.response.json.return_value = messages.Directory({ + 'meta': messages.Directory.Meta(external_account_required=False), + }).to_json() + + client = self._init() + + self.assertFalse(client.external_account_required()) + + +class ClientTest(ClientTestBase): + """Tests for acme.client.Client.""" + + def setUp(self): + super(ClientTest, self).setUp() + + self.directory = DIRECTORY_V1 + + # Registration + self.regr = self.regr.update( + terms_of_service='https://www.letsencrypt-demo.org/tos') + + # Request issuance + self.certr = messages.CertificateResource( + body=messages_test.CERT, authzrs=(self.authzr,), + uri='https://www.letsencrypt-demo.org/acme/cert/1', + cert_chain_uri='https://www.letsencrypt-demo.org/ca') + + from acme.client import Client + self.client = Client( + directory=self.directory, key=KEY, alg=jose.RS256, net=self.net) + + def test_init_downloads_directory(self): + uri = 'http://www.letsencrypt-demo.org/directory' + from acme.client import Client + self.client = Client( + directory=uri, key=KEY, alg=jose.RS256, net=self.net) + self.net.get.assert_called_once_with(uri) + + @mock.patch('acme.client.ClientNetwork') + def test_init_without_net(self, mock_net): + mock_net.return_value = mock.sentinel.net + alg = jose.RS256 + from acme.client import Client + self.client = Client( + directory=self.directory, key=KEY, alg=alg) + mock_net.called_once_with(KEY, alg=alg, verify_ssl=True) + self.assertEqual(self.client.net, mock.sentinel.net) + + def test_register(self): + # "Instance of 'Field' has no to_json/update member" bug: + self.response.status_code = http_client.CREATED + self.response.json.return_value = self.regr.body.to_json() + self.response.headers['Location'] = self.regr.uri + self.response.links.update({ + 'terms-of-service': {'url': self.regr.terms_of_service}, + }) + + self.assertEqual(self.regr, self.client.register(self.new_reg)) + # TODO: test POST call arguments + + def test_update_registration(self): + # "Instance of 'Field' has no to_json/update member" bug: + self.response.headers['Location'] = self.regr.uri + self.response.json.return_value = self.regr.body.to_json() + self.assertEqual(self.regr, self.client.update_registration(self.regr)) + # TODO: test POST call arguments + + # TODO: split here and separate test + self.response.json.return_value = self.regr.body.update( + contact=()).to_json() + + def test_deactivate_account(self): + self.response.headers['Location'] = self.regr.uri + self.response.json.return_value = self.regr.body.to_json() + self.assertEqual(self.regr, + self.client.deactivate_registration(self.regr)) + + def test_query_registration(self): + self.response.json.return_value = self.regr.body.to_json() + self.assertEqual(self.regr, self.client.query_registration(self.regr)) + + def test_agree_to_tos(self): + self.client.update_registration = mock.Mock() + self.client.agree_to_tos(self.regr) + regr = self.client.update_registration.call_args[0][0] + self.assertEqual(self.regr.terms_of_service, regr.body.agreement) + + def _prepare_response_for_request_challenges(self): + self.response.status_code = http_client.CREATED + self.response.headers['Location'] = self.authzr.uri + self.response.json.return_value = self.authz.to_json() + + def test_request_challenges(self): + self._prepare_response_for_request_challenges() + self.client.request_challenges(self.identifier) + self.net.post.assert_called_once_with( + self.directory.new_authz, + messages.NewAuthorization(identifier=self.identifier), + acme_version=1) + + def test_request_challenges_deprecated_arg(self): + self._prepare_response_for_request_challenges() + self.client.request_challenges(self.identifier, new_authzr_uri="hi") + self.net.post.assert_called_once_with( + self.directory.new_authz, + messages.NewAuthorization(identifier=self.identifier), + acme_version=1) + + def test_request_challenges_custom_uri(self): + self._prepare_response_for_request_challenges() + self.client.request_challenges(self.identifier) + self.net.post.assert_called_once_with( + 'https://www.letsencrypt-demo.org/acme/new-authz', mock.ANY, + acme_version=1) + + def test_request_challenges_unexpected_update(self): + self._prepare_response_for_request_challenges() + self.response.json.return_value = self.authz.update( + identifier=self.identifier.update(value='foo')).to_json() + self.assertRaises( + errors.UnexpectedUpdate, self.client.request_challenges, + self.identifier) + + def test_request_challenges_wildcard(self): + wildcard_identifier = messages.Identifier( + typ=messages.IDENTIFIER_FQDN, value='*.example.org') + self.assertRaises( + errors.WildcardUnsupportedError, self.client.request_challenges, + wildcard_identifier) + + def test_request_domain_challenges(self): + self.client.request_challenges = mock.MagicMock() + self.assertEqual( + self.client.request_challenges(self.identifier), + self.client.request_domain_challenges('example.com')) + + def test_answer_challenge(self): + self.response.links['up'] = {'url': self.challr.authzr_uri} + self.response.json.return_value = self.challr.body.to_json() + + chall_response = challenges.DNSResponse(validation=None) + + self.client.answer_challenge(self.challr.body, chall_response) + + # TODO: split here and separate test + self.assertRaises(errors.UnexpectedUpdate, self.client.answer_challenge, + self.challr.body.update(uri='foo'), chall_response) + + def test_answer_challenge_missing_next(self): + self.assertRaises( + errors.ClientError, self.client.answer_challenge, + self.challr.body, challenges.DNSResponse(validation=None)) + + def test_retry_after_date(self): + self.response.headers['Retry-After'] = 'Fri, 31 Dec 1999 23:59:59 GMT' + self.assertEqual( + datetime.datetime(1999, 12, 31, 23, 59, 59), + self.client.retry_after(response=self.response, default=10)) + + @mock.patch('acme.client.datetime') + def test_retry_after_invalid(self, dt_mock): + dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27) + dt_mock.timedelta = datetime.timedelta + + self.response.headers['Retry-After'] = 'foooo' + self.assertEqual( + datetime.datetime(2015, 3, 27, 0, 0, 10), + self.client.retry_after(response=self.response, default=10)) + + @mock.patch('acme.client.datetime') + def test_retry_after_overflow(self, dt_mock): + dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27) + dt_mock.timedelta = datetime.timedelta + dt_mock.datetime.side_effect = datetime.datetime + + self.response.headers['Retry-After'] = "Tue, 116 Feb 2016 11:50:00 MST" + self.assertEqual( + datetime.datetime(2015, 3, 27, 0, 0, 10), + self.client.retry_after(response=self.response, default=10)) + + @mock.patch('acme.client.datetime') + def test_retry_after_seconds(self, dt_mock): + dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27) + dt_mock.timedelta = datetime.timedelta + + self.response.headers['Retry-After'] = '50' + self.assertEqual( + datetime.datetime(2015, 3, 27, 0, 0, 50), + self.client.retry_after(response=self.response, default=10)) + + @mock.patch('acme.client.datetime') + def test_retry_after_missing(self, dt_mock): + dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27) + dt_mock.timedelta = datetime.timedelta + + self.assertEqual( + datetime.datetime(2015, 3, 27, 0, 0, 10), + self.client.retry_after(response=self.response, default=10)) + + def test_poll(self): + self.response.json.return_value = self.authzr.body.to_json() + self.assertEqual((self.authzr, self.response), + self.client.poll(self.authzr)) + + # TODO: split here and separate test + self.response.json.return_value = self.authz.update( + identifier=self.identifier.update(value='foo')).to_json() + self.assertRaises( + errors.UnexpectedUpdate, self.client.poll, self.authzr) + + def test_request_issuance(self): + self.response.content = CERT_DER + self.response.headers['Location'] = self.certr.uri + self.response.links['up'] = {'url': self.certr.cert_chain_uri} + self.assertEqual(self.certr, self.client.request_issuance( + messages_test.CSR, (self.authzr,))) + # TODO: check POST args + + def test_request_issuance_missing_up(self): + self.response.content = CERT_DER + self.response.headers['Location'] = self.certr.uri + self.assertEqual( + self.certr.update(cert_chain_uri=None), + self.client.request_issuance(messages_test.CSR, (self.authzr,))) + + def test_request_issuance_missing_location(self): + self.assertRaises( + errors.ClientError, self.client.request_issuance, + messages_test.CSR, (self.authzr,)) + + @mock.patch('acme.client.datetime') + @mock.patch('acme.client.time') + def test_poll_and_request_issuance(self, time_mock, dt_mock): + # clock.dt | pylint: disable=no-member + clock = mock.MagicMock(dt=datetime.datetime(2015, 3, 27)) + + def sleep(seconds): + """increment clock""" + clock.dt += datetime.timedelta(seconds=seconds) + time_mock.sleep.side_effect = sleep + + def now(): + """return current clock value""" + return clock.dt + dt_mock.datetime.now.side_effect = now + dt_mock.timedelta = datetime.timedelta + + def poll(authzr): # pylint: disable=missing-docstring + # record poll start time based on the current clock value + authzr.times.append(clock.dt) + + # suppose it takes 2 seconds for server to produce the + # result, increment clock + clock.dt += datetime.timedelta(seconds=2) + + if len(authzr.retries) == 1: # no more retries + done = mock.MagicMock(uri=authzr.uri, times=authzr.times) + done.body.status = authzr.retries[0] + return done, [] + + # response (2nd result tuple element) is reduced to only + # Retry-After header contents represented as integer + # seconds; authzr.retries is a list of Retry-After + # headers, head(retries) is peeled of as a current + # Retry-After header, and tail(retries) is persisted for + # later poll() calls + return (mock.MagicMock(retries=authzr.retries[1:], + uri=authzr.uri + '.', times=authzr.times), + authzr.retries[0]) + self.client.poll = mock.MagicMock(side_effect=poll) + + mintime = 7 + + def retry_after(response, default): + # pylint: disable=missing-docstring + # check that poll_and_request_issuance correctly passes mintime + self.assertEqual(default, mintime) + return clock.dt + datetime.timedelta(seconds=response) + self.client.retry_after = mock.MagicMock(side_effect=retry_after) + + def request_issuance(csr, authzrs): # pylint: disable=missing-docstring + return csr, authzrs + self.client.request_issuance = mock.MagicMock( + side_effect=request_issuance) + + csr = mock.MagicMock() + authzrs = ( + mock.MagicMock(uri='a', times=[], retries=( + 8, 20, 30, messages.STATUS_VALID)), + mock.MagicMock(uri='b', times=[], retries=( + 5, messages.STATUS_VALID)), + ) + + cert, updated_authzrs = self.client.poll_and_request_issuance( + csr, authzrs, mintime=mintime, + # make sure that max_attempts is per-authorization, rather + # than global + max_attempts=max(len(authzrs[0].retries), len(authzrs[1].retries))) + self.assertTrue(cert[0] is csr) + self.assertTrue(cert[1] is updated_authzrs) + self.assertEqual(updated_authzrs[0].uri, 'a...') + self.assertEqual(updated_authzrs[1].uri, 'b.') + self.assertEqual(updated_authzrs[0].times, [ + datetime.datetime(2015, 3, 27), + # a is scheduled for 10, but b is polling [9..11), so it + # will be picked up as soon as b is finished, without + # additional sleeping + datetime.datetime(2015, 3, 27, 0, 0, 11), + datetime.datetime(2015, 3, 27, 0, 0, 33), + datetime.datetime(2015, 3, 27, 0, 1, 5), + ]) + self.assertEqual(updated_authzrs[1].times, [ + datetime.datetime(2015, 3, 27, 0, 0, 2), + datetime.datetime(2015, 3, 27, 0, 0, 9), + ]) + self.assertEqual(clock.dt, datetime.datetime(2015, 3, 27, 0, 1, 7)) + + # CA sets invalid | TODO: move to a separate test + invalid_authzr = mock.MagicMock( + times=[], retries=[messages.STATUS_INVALID]) + self.assertRaises( + errors.PollError, self.client.poll_and_request_issuance, + csr, authzrs=(invalid_authzr,), mintime=mintime) + + # exceeded max_attempts | TODO: move to a separate test + self.assertRaises( + errors.PollError, self.client.poll_and_request_issuance, + csr, authzrs, mintime=mintime, max_attempts=2) + + def test_deactivate_authorization(self): + authzb = self.authzr.body.update(status=messages.STATUS_DEACTIVATED) + self.response.json.return_value = authzb.to_json() + authzr = self.client.deactivate_authorization(self.authzr) + self.assertEqual(authzb, authzr.body) + self.assertEqual(self.client.net.post.call_count, 1) + self.assertTrue(self.authzr.uri in self.net.post.call_args_list[0][0]) + + def test_check_cert(self): + self.response.headers['Location'] = self.certr.uri + self.response.content = CERT_DER + self.assertEqual(self.certr.update(body=messages_test.CERT), + self.client.check_cert(self.certr)) + + # TODO: split here and separate test + self.response.headers['Location'] = 'foo' + self.assertRaises( + errors.UnexpectedUpdate, self.client.check_cert, self.certr) + + def test_check_cert_missing_location(self): + self.response.content = CERT_DER + self.assertRaises( + errors.ClientError, self.client.check_cert, self.certr) + + def test_refresh(self): + self.client.check_cert = mock.MagicMock() + self.assertEqual( + self.client.check_cert(self.certr), self.client.refresh(self.certr)) + + def test_fetch_chain_no_up_link(self): + self.assertEqual([], self.client.fetch_chain(self.certr.update( + cert_chain_uri=None))) + + def test_fetch_chain_single(self): + # pylint: disable=protected-access + self.client._get_cert = mock.MagicMock() + self.client._get_cert.return_value = ( + mock.MagicMock(links={}), "certificate") + self.assertEqual([self.client._get_cert(self.certr.cert_chain_uri)[1]], + self.client.fetch_chain(self.certr)) + + def test_fetch_chain_max(self): + # pylint: disable=protected-access + up_response = mock.MagicMock(links={'up': {'url': 'http://cert'}}) + noup_response = mock.MagicMock(links={}) + self.client._get_cert = mock.MagicMock() + self.client._get_cert.side_effect = [ + (up_response, "cert")] * 9 + [(noup_response, "last_cert")] + chain = self.client.fetch_chain(self.certr, max_length=10) + self.assertEqual(chain, ["cert"] * 9 + ["last_cert"]) + + def test_fetch_chain_too_many(self): # recursive + # pylint: disable=protected-access + response = mock.MagicMock(links={'up': {'url': 'http://cert'}}) + self.client._get_cert = mock.MagicMock() + self.client._get_cert.return_value = (response, "certificate") + self.assertRaises(errors.Error, self.client.fetch_chain, self.certr) + + def test_revoke(self): + self.client.revoke(self.certr.body, self.rsn) + self.net.post.assert_called_once_with( + self.directory[messages.Revocation], mock.ANY, acme_version=1) + + def test_revocation_payload(self): + obj = messages.Revocation(certificate=self.certr.body, reason=self.rsn) + self.assertTrue('reason' in obj.to_partial_json().keys()) + self.assertEqual(self.rsn, obj.to_partial_json()['reason']) + + def test_revoke_bad_status_raises_error(self): + self.response.status_code = http_client.METHOD_NOT_ALLOWED + self.assertRaises( + errors.ClientError, + self.client.revoke, + self.certr, + self.rsn) + + +class ClientV2Test(ClientTestBase): + """Tests for acme.client.ClientV2.""" + + def setUp(self): + super(ClientV2Test, self).setUp() + + self.directory = DIRECTORY_V2 + + from acme.client import ClientV2 + self.client = ClientV2(self.directory, self.net) + + self.new_reg = self.new_reg.update(terms_of_service_agreed=True) + + self.authzr_uri2 = 'https://www.letsencrypt-demo.org/acme/authz/2' + self.authz2 = self.authz.update(identifier=messages.Identifier( + typ=messages.IDENTIFIER_FQDN, value='www.example.com'), + status=messages.STATUS_PENDING) + self.authzr2 = messages.AuthorizationResource( + body=self.authz2, uri=self.authzr_uri2) + + self.order = messages.Order( + identifiers=(self.authz.identifier, self.authz2.identifier), + status=messages.STATUS_PENDING, + authorizations=(self.authzr.uri, self.authzr_uri2), + finalize='https://www.letsencrypt-demo.org/acme/acct/1/order/1/finalize') + self.orderr = messages.OrderResource( + body=self.order, + uri='https://www.letsencrypt-demo.org/acme/acct/1/order/1', + authorizations=[self.authzr, self.authzr2], csr_pem=CSR_SAN_PEM) + + def test_new_account(self): + self.response.status_code = http_client.CREATED + self.response.json.return_value = self.regr.body.to_json() + self.response.headers['Location'] = self.regr.uri + + self.assertEqual(self.regr, self.client.new_account(self.new_reg)) + + def test_new_account_conflict(self): + self.response.status_code = http_client.OK + self.response.headers['Location'] = self.regr.uri + self.assertRaises(errors.ConflictError, self.client.new_account, self.new_reg) + + def test_new_order(self): + order_response = copy.deepcopy(self.response) + order_response.status_code = http_client.CREATED + order_response.json.return_value = self.order.to_json() + order_response.headers['Location'] = self.orderr.uri + self.net.post.return_value = order_response + + authz_response = copy.deepcopy(self.response) + authz_response.json.return_value = self.authz.to_json() + authz_response.headers['Location'] = self.authzr.uri + authz_response2 = self.response + authz_response2.json.return_value = self.authz2.to_json() + authz_response2.headers['Location'] = self.authzr2.uri + + with mock.patch('acme.client.ClientV2._post_as_get') as mock_post_as_get: + mock_post_as_get.side_effect = (authz_response, authz_response2) + self.assertEqual(self.client.new_order(CSR_SAN_PEM), self.orderr) + + @mock.patch('acme.client.datetime') + def test_poll_and_finalize(self, mock_datetime): + mock_datetime.datetime.now.return_value = datetime.datetime(2018, 2, 15) + mock_datetime.timedelta = datetime.timedelta + expected_deadline = mock_datetime.datetime.now() + datetime.timedelta(seconds=90) + + self.client.poll_authorizations = mock.Mock(return_value=self.orderr) + self.client.finalize_order = mock.Mock(return_value=self.orderr) + + self.assertEqual(self.client.poll_and_finalize(self.orderr), self.orderr) + self.client.poll_authorizations.assert_called_once_with(self.orderr, expected_deadline) + self.client.finalize_order.assert_called_once_with(self.orderr, expected_deadline) + + @mock.patch('acme.client.datetime') + def test_poll_authorizations_timeout(self, mock_datetime): + now_side_effect = [datetime.datetime(2018, 2, 15), + datetime.datetime(2018, 2, 16), + datetime.datetime(2018, 2, 17)] + mock_datetime.datetime.now.side_effect = now_side_effect + self.response.json.side_effect = [ + self.authz.to_json(), self.authz2.to_json(), self.authz2.to_json()] + + self.assertRaises( + errors.TimeoutError, self.client.poll_authorizations, self.orderr, now_side_effect[1]) + + def test_poll_authorizations_failure(self): + deadline = datetime.datetime(9999, 9, 9) + challb = self.challr.body.update(status=messages.STATUS_INVALID, + error=messages.Error.with_code('unauthorized')) + authz = self.authz.update(status=messages.STATUS_INVALID, challenges=(challb,)) + self.response.json.return_value = authz.to_json() + + self.assertRaises( + errors.ValidationError, self.client.poll_authorizations, self.orderr, deadline) + + def test_poll_authorizations_success(self): + deadline = datetime.datetime(9999, 9, 9) + updated_authz2 = self.authz2.update(status=messages.STATUS_VALID) + updated_authzr2 = messages.AuthorizationResource( + body=updated_authz2, uri=self.authzr_uri2) + updated_orderr = self.orderr.update(authorizations=[self.authzr, updated_authzr2]) + + self.response.json.side_effect = ( + self.authz.to_json(), self.authz2.to_json(), updated_authz2.to_json()) + self.assertEqual(self.client.poll_authorizations(self.orderr, deadline), updated_orderr) + + def test_finalize_order_success(self): + updated_order = self.order.update( + certificate='https://www.letsencrypt-demo.org/acme/cert/') + updated_orderr = self.orderr.update(body=updated_order, fullchain_pem=CERT_SAN_PEM) + + self.response.json.return_value = updated_order.to_json() + self.response.text = CERT_SAN_PEM + + deadline = datetime.datetime(9999, 9, 9) + self.assertEqual(self.client.finalize_order(self.orderr, deadline), updated_orderr) + + def test_finalize_order_error(self): + updated_order = self.order.update(error=messages.Error.with_code('unauthorized')) + self.response.json.return_value = updated_order.to_json() + + deadline = datetime.datetime(9999, 9, 9) + self.assertRaises(errors.IssuanceError, self.client.finalize_order, self.orderr, deadline) + + def test_finalize_order_timeout(self): + deadline = datetime.datetime.now() - datetime.timedelta(seconds=60) + self.assertRaises(errors.TimeoutError, self.client.finalize_order, self.orderr, deadline) + + def test_revoke(self): + self.client.revoke(messages_test.CERT, self.rsn) + self.net.post.assert_called_once_with( + self.directory["revokeCert"], mock.ANY, acme_version=2, + new_nonce_url=DIRECTORY_V2['newNonce']) + + def test_update_registration(self): + # "Instance of 'Field' has no to_json/update member" bug: + self.response.headers['Location'] = self.regr.uri + self.response.json.return_value = self.regr.body.to_json() + self.assertEqual(self.regr, self.client.update_registration(self.regr)) + self.assertNotEqual(self.client.net.account, None) + self.assertEqual(self.client.net.post.call_count, 2) + self.assertTrue(DIRECTORY_V2.newAccount in self.net.post.call_args_list[0][0]) + + self.response.json.return_value = self.regr.body.update( + contact=()).to_json() + + def test_external_account_required_true(self): + self.client.directory = messages.Directory({ + 'meta': messages.Directory.Meta(external_account_required=True) + }) + + self.assertTrue(self.client.external_account_required()) + + def test_external_account_required_false(self): + self.client.directory = messages.Directory({ + 'meta': messages.Directory.Meta(external_account_required=False) + }) + + self.assertFalse(self.client.external_account_required()) + + def test_external_account_required_default(self): + self.assertFalse(self.client.external_account_required()) + + def test_post_as_get(self): + with mock.patch('acme.client.ClientV2._authzr_from_response') as mock_client: + mock_client.return_value = self.authzr2 + + self.client.poll(self.authzr2) # pylint: disable=protected-access + + self.client.net.post.assert_called_once_with( + self.authzr2.uri, None, acme_version=2, + new_nonce_url='https://www.letsencrypt-demo.org/acme/new-nonce') + self.client.net.get.assert_not_called() + + class FakeError(messages.Error): + """Fake error to reproduce a malformed request ACME error""" + def __init__(self): # pylint: disable=super-init-not-called + pass + @property + def code(self): + return 'malformed' + self.client.net.post.side_effect = FakeError() + + self.client.poll(self.authzr2) # pylint: disable=protected-access + + self.client.net.get.assert_called_once_with(self.authzr2.uri) + + +class MockJSONDeSerializable(jose.JSONDeSerializable): + # pylint: disable=missing-docstring + def __init__(self, value): + self.value = value + + def to_partial_json(self): + return {'foo': self.value} + + @classmethod + def from_json(cls, jobj): + pass # pragma: no cover + + +class ClientNetworkTest(unittest.TestCase): + """Tests for acme.client.ClientNetwork.""" + + def setUp(self): + self.verify_ssl = mock.MagicMock() + self.wrap_in_jws = mock.MagicMock(return_value=mock.sentinel.wrapped) + + from acme.client import ClientNetwork + self.net = ClientNetwork( + key=KEY, alg=jose.RS256, verify_ssl=self.verify_ssl, + user_agent='acme-python-test') + + self.response = mock.MagicMock(ok=True, status_code=http_client.OK) + self.response.headers = {} + self.response.links = {} + + def test_init(self): + self.assertTrue(self.net.verify_ssl is self.verify_ssl) + + def test_wrap_in_jws(self): + # pylint: disable=protected-access + jws_dump = self.net._wrap_in_jws( + MockJSONDeSerializable('foo'), nonce=b'Tg', url="url", + acme_version=1) + jws = acme_jws.JWS.json_loads(jws_dump) + self.assertEqual(json.loads(jws.payload.decode()), {'foo': 'foo'}) + self.assertEqual(jws.signature.combined.nonce, b'Tg') + + def test_wrap_in_jws_v2(self): + self.net.account = {'uri': 'acct-uri'} + # pylint: disable=protected-access + jws_dump = self.net._wrap_in_jws( + MockJSONDeSerializable('foo'), nonce=b'Tg', url="url", + acme_version=2) + jws = acme_jws.JWS.json_loads(jws_dump) + self.assertEqual(json.loads(jws.payload.decode()), {'foo': 'foo'}) + self.assertEqual(jws.signature.combined.nonce, b'Tg') + self.assertEqual(jws.signature.combined.kid, u'acct-uri') + self.assertEqual(jws.signature.combined.url, u'url') + + def test_check_response_not_ok_jobj_no_error(self): + self.response.ok = False + self.response.json.return_value = {} + with mock.patch('acme.client.messages.Error.from_json') as from_json: + from_json.side_effect = jose.DeserializationError + # pylint: disable=protected-access + self.assertRaises( + errors.ClientError, self.net._check_response, self.response) + + def test_check_response_not_ok_jobj_error(self): + self.response.ok = False + self.response.json.return_value = messages.Error( + detail='foo', typ='serverInternal', title='some title').to_json() + # pylint: disable=protected-access + self.assertRaises( + messages.Error, self.net._check_response, self.response) + + def test_check_response_not_ok_no_jobj(self): + self.response.ok = False + self.response.json.side_effect = ValueError + # pylint: disable=protected-access + self.assertRaises( + errors.ClientError, self.net._check_response, self.response) + + def test_check_response_ok_no_jobj_ct_required(self): + self.response.json.side_effect = ValueError + for response_ct in [self.net.JSON_CONTENT_TYPE, 'foo']: + self.response.headers['Content-Type'] = response_ct + # pylint: disable=protected-access + self.assertRaises( + errors.ClientError, self.net._check_response, self.response, + content_type=self.net.JSON_CONTENT_TYPE) + + def test_check_response_ok_no_jobj_no_ct(self): + self.response.json.side_effect = ValueError + for response_ct in [self.net.JSON_CONTENT_TYPE, 'foo']: + self.response.headers['Content-Type'] = response_ct + # pylint: disable=protected-access,no-value-for-parameter + self.assertEqual( + self.response, self.net._check_response(self.response)) + + def test_check_response_conflict(self): + self.response.ok = False + self.response.status_code = 409 + # pylint: disable=protected-access + self.assertRaises(errors.ConflictError, self.net._check_response, self.response) + + def test_check_response_jobj(self): + self.response.json.return_value = {} + for response_ct in [self.net.JSON_CONTENT_TYPE, 'foo']: + self.response.headers['Content-Type'] = response_ct + # pylint: disable=protected-access,no-value-for-parameter + self.assertEqual( + self.response, self.net._check_response(self.response)) + + def test_send_request(self): + self.net.session = mock.MagicMock() + self.net.session.request.return_value = self.response + # pylint: disable=protected-access + self.assertEqual(self.response, self.net._send_request( + 'HEAD', 'http://example.com/', 'foo', bar='baz')) + self.net.session.request.assert_called_once_with( + 'HEAD', 'http://example.com/', 'foo', + headers=mock.ANY, verify=mock.ANY, timeout=mock.ANY, bar='baz') + + @mock.patch('acme.client.logger') + def test_send_request_get_der(self, mock_logger): + self.net.session = mock.MagicMock() + self.net.session.request.return_value = mock.MagicMock( + ok=True, status_code=http_client.OK, + headers={"Content-Type": "application/pkix-cert"}, + content=b"hi") + # pylint: disable=protected-access + self.net._send_request('HEAD', 'http://example.com/', 'foo', + timeout=mock.ANY, bar='baz') + mock_logger.debug.assert_called_with( + 'Received response:\nHTTP %d\n%s\n\n%s', 200, + 'Content-Type: application/pkix-cert', b'aGk=') + + def test_send_request_post(self): + self.net.session = mock.MagicMock() + self.net.session.request.return_value = self.response + # pylint: disable=protected-access + self.assertEqual(self.response, self.net._send_request( + 'POST', 'http://example.com/', 'foo', data='qux', bar='baz')) + self.net.session.request.assert_called_once_with( + 'POST', 'http://example.com/', 'foo', + headers=mock.ANY, verify=mock.ANY, timeout=mock.ANY, data='qux', bar='baz') + + def test_send_request_verify_ssl(self): + # pylint: disable=protected-access + for verify in True, False: + self.net.session = mock.MagicMock() + self.net.session.request.return_value = self.response + self.net.verify_ssl = verify + # pylint: disable=protected-access + self.assertEqual( + self.response, + self.net._send_request('GET', 'http://example.com/')) + self.net.session.request.assert_called_once_with( + 'GET', 'http://example.com/', verify=verify, + timeout=mock.ANY, headers=mock.ANY) + + def test_send_request_user_agent(self): + self.net.session = mock.MagicMock() + # pylint: disable=protected-access + self.net._send_request('GET', 'http://example.com/', + headers={'bar': 'baz'}) + self.net.session.request.assert_called_once_with( + 'GET', 'http://example.com/', verify=mock.ANY, + timeout=mock.ANY, + headers={'User-Agent': 'acme-python-test', 'bar': 'baz'}) + + self.net._send_request('GET', 'http://example.com/', + headers={'User-Agent': 'foo2'}) + self.net.session.request.assert_called_with( + 'GET', 'http://example.com/', + verify=mock.ANY, timeout=mock.ANY, headers={'User-Agent': 'foo2'}) + + def test_send_request_timeout(self): + self.net.session = mock.MagicMock() + # pylint: disable=protected-access + self.net._send_request('GET', 'http://example.com/', + headers={'bar': 'baz'}) + self.net.session.request.assert_called_once_with( + mock.ANY, mock.ANY, verify=mock.ANY, headers=mock.ANY, + timeout=45) + + def test_del(self, close_exception=None): + sess = mock.MagicMock() + + if close_exception is not None: + sess.close.side_effect = close_exception + + self.net.session = sess + del self.net + sess.close.assert_called_once_with() + + def test_del_error(self): + self.test_del(ReferenceError) + + @mock.patch('acme.client.requests') + def test_requests_error_passthrough(self, mock_requests): + mock_requests.exceptions = requests.exceptions + mock_requests.request.side_effect = requests.exceptions.RequestException + # pylint: disable=protected-access + self.assertRaises(requests.exceptions.RequestException, + self.net._send_request, 'GET', 'uri') + + def test_urllib_error(self): + # Using a connection error to test a properly formatted error message + try: + # pylint: disable=protected-access + self.net._send_request('GET', "http://localhost:19123/nonexistent.txt") + + # Value Error Generated Exceptions + except ValueError as y: + self.assertEqual("Requesting localhost/nonexistent: " + "Connection refused", str(y)) + + # Requests Library Exceptions + except requests.exceptions.ConnectionError as z: #pragma: no cover + self.assertTrue("'Connection aborted.'" in str(z) or "[WinError 10061]" in str(z)) + + +class ClientNetworkWithMockedResponseTest(unittest.TestCase): + """Tests for acme.client.ClientNetwork which mock out response.""" + + def setUp(self): + from acme.client import ClientNetwork + self.net = ClientNetwork(key=None, alg=None) + + self.response = mock.MagicMock(ok=True, status_code=http_client.OK) + self.response.headers = {} + self.response.links = {} + self.response.checked = False + self.acmev1_nonce_response = mock.MagicMock(ok=False, + status_code=http_client.METHOD_NOT_ALLOWED) + self.acmev1_nonce_response.headers = {} + self.obj = mock.MagicMock() + self.wrapped_obj = mock.MagicMock() + self.content_type = mock.sentinel.content_type + + self.all_nonces = [ + jose.b64encode(b'Nonce'), + jose.b64encode(b'Nonce2'), jose.b64encode(b'Nonce3')] + self.available_nonces = self.all_nonces[:] + + def send_request(*args, **kwargs): + # pylint: disable=unused-argument,missing-docstring + self.assertFalse("new_nonce_url" in kwargs) + method = args[0] + uri = args[1] + if method == 'HEAD' and uri != "new_nonce_uri": + response = self.acmev1_nonce_response + else: + response = self.response + + if self.available_nonces: + response.headers = { + self.net.REPLAY_NONCE_HEADER: + self.available_nonces.pop().decode()} + else: + response.headers = {} + return response + + # pylint: disable=protected-access + self.net._send_request = self.send_request = mock.MagicMock( + side_effect=send_request) + self.net._check_response = self.check_response + self.net._wrap_in_jws = mock.MagicMock(return_value=self.wrapped_obj) + + def check_response(self, response, content_type): + # pylint: disable=missing-docstring + self.assertEqual(self.response, response) + self.assertEqual(self.content_type, content_type) + self.assertTrue(self.response.ok) + self.response.checked = True + return self.response + + def test_head(self): + self.assertEqual(self.acmev1_nonce_response, self.net.head( + 'http://example.com/', 'foo', bar='baz')) + self.send_request.assert_called_once_with( + 'HEAD', 'http://example.com/', 'foo', bar='baz') + + def test_head_v2(self): + self.assertEqual(self.response, self.net.head( + 'new_nonce_uri', 'foo', bar='baz')) + self.send_request.assert_called_once_with( + 'HEAD', 'new_nonce_uri', 'foo', bar='baz') + + def test_get(self): + self.assertEqual(self.response, self.net.get( + 'http://example.com/', content_type=self.content_type, bar='baz')) + self.assertTrue(self.response.checked) + self.send_request.assert_called_once_with( + 'GET', 'http://example.com/', bar='baz') + + def test_post_no_content_type(self): + self.content_type = self.net.JOSE_CONTENT_TYPE + self.assertEqual(self.response, self.net.post('uri', self.obj)) + self.assertTrue(self.response.checked) + + def test_post(self): + # pylint: disable=protected-access + self.assertEqual(self.response, self.net.post( + 'uri', self.obj, content_type=self.content_type)) + self.assertTrue(self.response.checked) + self.net._wrap_in_jws.assert_called_once_with( + self.obj, jose.b64decode(self.all_nonces.pop()), "uri", 1) + + self.available_nonces = [] + self.assertRaises(errors.MissingNonce, self.net.post, + 'uri', self.obj, content_type=self.content_type) + self.net._wrap_in_jws.assert_called_with( + self.obj, jose.b64decode(self.all_nonces.pop()), "uri", 1) + + def test_post_wrong_initial_nonce(self): # HEAD + self.available_nonces = [b'f', jose.b64encode(b'good')] + self.assertRaises(errors.BadNonce, self.net.post, 'uri', + self.obj, content_type=self.content_type) + + def test_post_wrong_post_response_nonce(self): + self.available_nonces = [jose.b64encode(b'good'), b'f'] + self.assertRaises(errors.BadNonce, self.net.post, 'uri', + self.obj, content_type=self.content_type) + + def test_post_failed_retry(self): + check_response = mock.MagicMock() + check_response.side_effect = messages.Error.with_code('badNonce') + + # pylint: disable=protected-access + self.net._check_response = check_response + self.assertRaises(messages.Error, self.net.post, 'uri', + self.obj, content_type=self.content_type) + + def test_post_not_retried(self): + check_response = mock.MagicMock() + check_response.side_effect = [messages.Error.with_code('malformed'), + self.response] + + # pylint: disable=protected-access + self.net._check_response = check_response + self.assertRaises(messages.Error, self.net.post, 'uri', + self.obj, content_type=self.content_type) + + def test_post_successful_retry(self): + post_once = mock.MagicMock() + post_once.side_effect = [messages.Error.with_code('badNonce'), + self.response] + + # pylint: disable=protected-access + self.assertEqual(self.response, self.net.post( + 'uri', self.obj, content_type=self.content_type)) + + def test_head_get_post_error_passthrough(self): + self.send_request.side_effect = requests.exceptions.RequestException + for method in self.net.head, self.net.get: + self.assertRaises( + requests.exceptions.RequestException, method, 'GET', 'uri') + self.assertRaises(requests.exceptions.RequestException, + self.net.post, 'uri', obj=self.obj) + + def test_post_bad_nonce_head(self): + # pylint: disable=protected-access + # regression test for https://github.com/certbot/certbot/issues/6092 + bad_response = mock.MagicMock(ok=False, status_code=http_client.SERVICE_UNAVAILABLE) + self.net._send_request = mock.MagicMock() + self.net._send_request.return_value = bad_response + self.content_type = None + check_response = mock.MagicMock() + self.net._check_response = check_response + self.assertRaises(errors.ClientError, self.net.post, 'uri', + self.obj, content_type=self.content_type, acme_version=2, + new_nonce_url='new_nonce_uri') + self.assertEqual(check_response.call_count, 1) + + def test_new_nonce_uri_removed(self): + self.content_type = None + self.net.post('uri', self.obj, content_type=None, + acme_version=2, new_nonce_url='new_nonce_uri') + + +class ClientNetworkSourceAddressBindingTest(unittest.TestCase): + """Tests that if ClientNetwork has a source IP set manually, the underlying library has + used the provided source address.""" + + def setUp(self): + self.source_address = "8.8.8.8" + + def test_source_address_set(self): + from acme.client import ClientNetwork + net = ClientNetwork(key=None, alg=None, source_address=self.source_address) + for adapter in net.session.adapters.values(): + self.assertTrue(self.source_address in adapter.source_address) + + def test_behavior_assumption(self): + """This is a test that guardrails the HTTPAdapter behavior so that if the default for + a Session() changes, the assumptions here aren't violated silently.""" + from acme.client import ClientNetwork + # Source address not specified, so the default adapter type should be bound -- this + # test should fail if the default adapter type is changed by requests + net = ClientNetwork(key=None, alg=None) + session = requests.Session() + for scheme in session.adapters.keys(): + client_network_adapter = net.session.adapters.get(scheme) + default_adapter = session.adapters.get(scheme) + self.assertEqual(client_network_adapter.__class__, default_adapter.__class__) + +if __name__ == '__main__': + unittest.main() # pragma: no cover |