240 lines
8.8 KiB
Python
Executable File
240 lines
8.8 KiB
Python
Executable File
import logging
|
|
import re
|
|
import os
|
|
import socket
|
|
import OpenSSL
|
|
import time
|
|
import sys
|
|
|
|
from datetime import datetime
|
|
from datetime import tzinfo
|
|
from datetime import timedelta
|
|
from http.client import HTTPConnection
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
# Number of seconds in one day (seconds * minutes * hours); used to turn
# day offsets into second offsets for certificate validity periods.
SEC_PER_DAY = 60 * 60 * 24

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class MDCertUtil(object):
    """Utility class for inspecting certificates in test cases.

    Uses PyOpenSSL: https://pyopenssl.org/en/stable/index.html
    """

    @classmethod
    def create_self_signed_cert(cls, path, name_list, valid_days, serial=1000):
        """Create a self-signed certificate plus key in directory `path`.

        The certificate is written to `<path>/pubcert.pem` and the key to
        `<path>/privkey.pem`. An already existing `privkey.pem` is reused,
        otherwise a new 2048 bit RSA key is generated.

        :param path: directory to write to, created when missing
        :param name_list: DNS names for the certificate; the first one
            becomes the subject CN, all of them go into subjectAltName
        :param valid_days: mapping with the keys "notBefore" and "notAfter",
            each an offset in days relative to the current time
        :param serial: serial number of the certificate
        """
        domain = name_list[0]
        os.makedirs(path, exist_ok=True)

        cert_file = os.path.join(path, 'pubcert.pem')
        pkey_file = os.path.join(path, 'privkey.pem')
        # reuse an existing key, otherwise create a new key pair
        if os.path.exists(pkey_file):
            with open(pkey_file, 'rt') as fd:
                key_buffer = fd.read()
            k = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, key_buffer)
        else:
            k = OpenSSL.crypto.PKey()
            k.generate_key(OpenSSL.crypto.TYPE_RSA, 2048)

        # create a self-signed cert
        cert = OpenSSL.crypto.X509()
        subject = cert.get_subject()
        subject.C = "DE"
        subject.ST = "NRW"
        subject.L = "Muenster"
        subject.O = "greenbytes GmbH"
        subject.CN = domain
        cert.set_serial_number(serial)
        cert.gmtime_adj_notBefore(valid_days["notBefore"] * SEC_PER_DAY)
        cert.gmtime_adj_notAfter(valid_days["notAfter"] * SEC_PER_DAY)
        # self-signed: the issuer is the subject
        cert.set_issuer(cert.get_subject())

        san = ", ".join("DNS:" + name for name in name_list).encode()
        cert.add_extensions([OpenSSL.crypto.X509Extension(
            b"subjectAltName", False, san
        )])
        cert.set_pubkey(k)
        # NOTE(review): SHA-1 is kept as in the original; consider 'sha256'
        # if the consuming tests do not depend on the digest.
        cert.sign(k, 'sha1')

        with open(cert_file, "wt") as fd:
            fd.write(OpenSSL.crypto.dump_certificate(
                OpenSSL.crypto.FILETYPE_PEM, cert).decode('utf-8'))
        with open(pkey_file, "wt") as fd:
            fd.write(OpenSSL.crypto.dump_privatekey(
                OpenSSL.crypto.FILETYPE_PEM, k).decode('utf-8'))

    @classmethod
    def load_server_cert(cls, host_ip, host_port, host_name, tls=None, ciphers=None):
        """Connect to a TLS server and return its certificate.

        :param host_ip: address to connect to
        :param host_port: port to connect to (converted with int())
        :param host_name: name sent as TLS server name indication
        :param tls: when given (1.0, 1.1, 1.2 or 1.3), all other of these
            protocol versions are disabled
        :param ciphers: optional cipher list passed to the SSL context
        :return: MDCertUtil wrapping the peer certificate
        """
        ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
        if tls is not None:
            # disable every protocol version but the requested one
            if tls != 1.0:
                ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1)
            if tls != 1.1:
                ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1_1)
            if tls != 1.2:
                ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1_2)
            # not every PyOpenSSL version knows about TLSv1.3
            if tls != 1.3 and hasattr(OpenSSL.SSL, "OP_NO_TLSv1_3"):
                ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1_3)
        if ciphers is not None:
            ctx.set_cipher_list(ciphers)
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            connection = OpenSSL.SSL.Connection(ctx, s)
            connection.connect((host_ip, int(host_port)))
            connection.setblocking(1)
            connection.set_tlsext_host_name(host_name.encode('utf-8'))
            connection.do_handshake()
            peer_cert = connection.get_peer_certificate()
        finally:
            # always release the socket, also when the handshake fails
            s.close()
        return MDCertUtil(None, cert=peer_cert)

    @classmethod
    def parse_pem_cert(cls, text):
        """Return a MDCertUtil for the PEM certificate in the string `text`."""
        cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, text.encode('utf-8'))
        return MDCertUtil(None, cert=cert)

    @classmethod
    def get_plain(cls, url, timeout):
        """GET `url` over plain HTTP, retrying until `timeout` seconds passed.

        Only host, port and path of the URL are used.

        :param url: the URL to retrieve
        :param timeout: seconds to keep trying, also used as the timeout of
            each single connection
        :return: the response body as bytes, or None when no attempt
            succeeded in time
        """
        server = urlparse(url)
        try_until = time.time() + timeout
        while time.time() < try_until:
            c = None
            try:
                c = HTTPConnection(server.hostname, server.port, timeout=timeout)
                c.request('GET', server.path)
                return c.getresponse().read()
            except IOError as error:
                log.debug("connect error: %s", error)
            except Exception as error:
                log.error("Unexpected error: %s", error)
            finally:
                if c is not None:
                    c.close()
            # pause between attempts instead of spinning on a failing server
            time.sleep(.1)
        log.error("Unable to contact server after %d sec", timeout)
        return None

    def __init__(self, cert_path, cert=None):
        """Wrap a certificate, given directly or loaded from `cert_path`.

        :param cert_path: file path or http(s) URL of a certificate in PEM
            or DER format, or None when `cert` is given
        :param cert: an already loaded certificate object; takes precedence
            over whatever was loaded from `cert_path`
        :raises ValueError: when neither `cert_path` nor `cert` is given
        :raises Exception: the last load error when `cert_path` could not
            be parsed and no `cert` was given
        """
        self.cert_path = cert_path
        self.cert = None
        self.error = None
        if cert_path is not None:
            # load certificate data; only real URLs are fetched, so that a
            # local path that merely starts with "http" is read from disk
            if cert_path.startswith(("http://", "https://")):
                cert_data = self.get_plain(cert_path, 1)
            else:
                cert_data = MDCertUtil._load_binary_file(cert_path)

            for file_type in (OpenSSL.crypto.FILETYPE_PEM, OpenSSL.crypto.FILETYPE_ASN1):
                try:
                    self.cert = OpenSSL.crypto.load_certificate(file_type, cert_data)
                    self.error = None
                    break
                except Exception as error:
                    self.error = error
        if cert is not None:
            self.cert = cert

        if self.cert is None:
            if self.error is None:
                raise ValueError("neither cert_path nor cert given")
            raise self.error

    def get_issuer(self):
        """Return the issuer name object of the certificate."""
        return self.cert.get_issuer()

    def get_serial(self):
        """Return the serial number as upper case hex string.

        The string representation of a serial number is not unique. Some
        add leading 0s to align with word boundaries; none are added here.
        """
        return "%X" % self.cert.get_serial_number()

    def same_serial_as(self, other):
        """Tell if `other` carries the same serial number as this certificate.

        :param other: a MDCertUtil, a hex string, an int or a PyOpenSSL
            X509 certificate
        :return: True on equal serial numbers, False otherwise and for any
            other type of `other`
        """
        serial = self.cert.get_serial_number()
        if isinstance(other, MDCertUtil):
            return serial == other.cert.get_serial_number()
        if isinstance(other, str):
            # assume a hex number
            return serial == int(other, 16)
        if isinstance(other, int):
            return serial == other
        if isinstance(other, OpenSSL.crypto.X509):
            return serial == other.get_serial_number()
        return False

    def get_not_before(self):
        """Return the notBefore timestamp as timezone aware datetime."""
        tsp = self.cert.get_notBefore()
        return self._parse_tsp(tsp)

    def get_not_after(self):
        """Return the notAfter timestamp as timezone aware datetime."""
        tsp = self.cert.get_notAfter()
        return self._parse_tsp(tsp)

    def get_cn(self):
        """Return the common name of the certificate subject."""
        return self.cert.get_subject().CN

    def get_key_length(self):
        """Return the number of bits of the certificate's public key."""
        return self.cert.get_pubkey().bits()

    def get_san_list(self):
        """Return the subject alternative names listed in the certificate.

        The names are taken from the text dump of the certificate; a
        leading "DNS:" is removed from each entry. Returns an empty list
        when the dump shows no subject alternative name extension.
        """
        text = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_TEXT, self.cert).decode("utf-8")
        # an extension marked critical prints "critical" after the colon,
        # skip it so that it is not mistaken for the list of names
        m = re.search(r"X509v3 Subject Alternative Name:(?:\s*critical)?\s*(.*)", text)
        if not m:
            return []

        def _strip_prefix(s):
            s = s.strip()
            return s[len("DNS:"):] if s.startswith("DNS:") else s

        return [_strip_prefix(s) for s in m.group(1).split(",")]

    def get_must_staple(self):
        """Tell if the text dump of the certificate shows OCSP must-staple."""
        text = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_TEXT, self.cert).decode("utf-8")
        # the dots of the OID are escaped to match literally
        m = re.search(r"1\.3\.6\.1\.5\.5\.7\.1\.24:\s*\n\s*0....", text)
        if not m:
            # Newer openssl versions print this differently
            m = re.search(r"TLS Feature:\s*\n\s*status_request\s*\n", text)
        return m is not None

    @classmethod
    def validate_privkey(cls, privkey_path, passphrase=None):
        """Load the PEM private key in `privkey_path` and check it.

        :param privkey_path: path of the PEM encoded private key
        :param passphrase: optional passphrase of the key, str or bytes
        :return: the result of the key's check() call
        """
        privkey_data = cls._load_binary_file(privkey_path)
        if passphrase:
            # PyOpenSSL wants the passphrase as bytes
            if isinstance(passphrase, str):
                passphrase = passphrase.encode('utf-8')
            privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, privkey_data, passphrase)
        else:
            privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, privkey_data)
        return privkey.check()

    def validate_cert_matches_priv_key(self, privkey_path):
        """Verify that the private key in `privkey_path` and the cert match.

        Returns nothing; a mismatch surfaces as the exception raised by
        the SSL context's check_privatekey().
        """
        privkey_data = MDCertUtil._load_binary_file(privkey_path)
        privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, privkey_data)
        context = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
        context.use_privatekey(privkey)
        context.use_certificate(self.cert)
        context.check_privatekey()

    # --------- _utils_ ---------

    def astr(self, s):
        """Return the bytes `s` decoded as UTF-8 string."""
        return s.decode('utf-8')

    def _parse_tsp(self, tsp):
        """Convert a timestamp like b"20200102030405Z" into a datetime.

        :param tsp: timestamp as bytes, "YYYYMMDDhhmmss" followed by an
            optional "+hhmm"/"-hhmm" offset; anything else after the
            seconds is treated as UTC
        :return: timezone aware datetime
        """
        # timestamps returned by PyOpenSSL are bytes
        text = self.astr(tsp)
        # parse date and time part
        s = "%s-%s-%s %s:%s:%s" % (text[0:4], text[4:6], text[6:8],
                                   text[8:10], text[10:12], text[12:14])
        timestamp = datetime.strptime(s, '%Y-%m-%d %H:%M:%S')
        # adjust timezone; the sign is parsed on its own so that it applies
        # to hours and minutes alike, also for offsets below one hour
        offset = 0
        m = re.match(r"([+\-])(\d{2})(\d{2})", text[14:])
        if m:
            offset = 60 * int(m.group(2)) + int(m.group(3))
            if m.group(1) == "-":
                offset = -offset
        return timestamp.replace(tzinfo=self.FixedOffset(offset))

    @classmethod
    def _load_binary_file(cls, path):
        """Return the content of the file at `path` as bytes."""
        with open(path, mode="rb") as file:
            return file.read()

    class FixedOffset(tzinfo):
        """Timezone with a fixed offset to UTC, given in minutes."""

        def __init__(self, offset):
            self.__offset = timedelta(minutes=offset)

        def utcoffset(self, dt):
            return self.__offset

        def tzname(self, dt):
            # no name available for a plain offset
            return None

        def dst(self, dt):
            return timedelta(0)
|