#!/usr/local/bin/python3.11
# -*- coding: utf-8 -*-

""" Python code to start a RIPE Atlas UDM (User-Defined
Measurement). This one is to test X.509/PKIX certificates in TLS servers.

You'll need an API key in ~/.atlas/auth.

After launching the measurement, it downloads the results and analyzes
them, displaying the name ("subject" in X.509 parlance) or, on request,
the issuer, the public key, the serial number or the expiration date.

Stéphane Bortzmeyer <stephane+frama@bortzmeyer.org>
"""

import json
import time
import os
import string
import re
import sys
import time
import socket
import collections
import copy

import Blaeu

# https://github.com/pyca/pyopenssl https://pyopenssl.readthedocs.org/en/stable/
import OpenSSL.crypto
import cryptography

# Run-time configuration object provided by the Blaeu package. It is
# updated by the command-line parsing further down (config.parse, with
# specificParse handling the options specific to this program).
config = Blaeu.Config()
# Default values
# Which certificate field is displayed: "n" is the subject name;
# specificParse may replace it with "i" (issuer), "k" (public key),
# "s" (serial number) or "e" (expiration).
config.display = "n" #Name
# Ask for the SNI to be sent unless --no-sni is given.
config.sni = True
# Override what's in the Blaeu package
config.port = 443

class Set:
    """Tally of the results that produced one given displayed value."""

    def __init__(self):
        # Number of results counted so far; incremented by the analysis loop.
        self.total = 0

def usage(msg=None):
    """Print the command-line help on standard error.

    The synopsis comes first, then the generic help printed by
    config.usage() (which receives msg unchanged), then the options that
    are specific to this program.
    """
    print("Usage: %s target-name-or-IP" % sys.argv[0], file=sys.stderr)
    config.usage(msg)
    specific_help = (
        "Also:",
        "    --issuer or -I : displays the issuer (default is to display the name)",
        "    --key or -k : displays the public key (default is to display the name)",
        "    --serial or -S : displays the serial number (default is to display the name)",
        "    --expiration or -E : displays the expiration datetime (default is to display the name)",
        "    --no-sni : do not send the SNI (Server Name Indication) (default is to send it)",
    )
    print("\n".join(specific_help), file=sys.stderr)

def format_name(n):
    """Return an X.509 name as a "/KEY=value/KEY=value..." string.

    n is an object whose get_components() method returns (key, value)
    pairs of bytes; in this program it is the subject or the issuer of a
    certificate. An empty list of components gives an empty string.

    The certificates come from arbitrary servers, so a component is not
    guaranteed to be valid UTF-8. Undecodable bytes are rendered as \\xNN
    escapes instead of raising UnicodeDecodeError, which would abort the
    analysis of all the results.
    """
    return "".join(
        "/%s=%s" % (k.decode(errors="backslashreplace"),
                    v.decode(errors="backslashreplace"))
        for (k, v) in n.get_components())

def specificParse(config, option, value):
    """Handle the command-line options that are specific to this program.

    This function is handed to config.parse() as a callback. When option
    is one of the display switches, config.display is set to the matching
    one-letter code; when it is --no-sni, config.sni is set to False.
    value is not used, none of these options takes an argument.

    Returns True if the option was recognized, False otherwise.
    """
    display_codes = {
        "--issuer": "i", "-I": "i",
        "--key": "k", "-k": "k",
        "--serial": "s", "-S": "s",
        "--expiration": "e", "-E": "e",
    }
    if option in display_codes:
        config.display = display_codes[option]
        return True
    if option == "--no-sni":
        config.sni = False
        return True
    return False

# Parse the command line. The short and long options specific to this
# program are declared here and handled by specificParse; usage is given
# to the parser as well. The parser returns the remaining positional
# arguments and the measurement description that is completed below.
args, data = config.parse(
    "IkSE",
    ["issuer", "serial", "expiration", "key", "no-sni"],
    specificParse,
    usage,
)

# Exactly one positional argument is expected: the server to test.
if len(args) != 1:
    usage("Not the good number of arguments")
    sys.exit(1)
target = args[0]

if config.measurement_id is not None:
    # An existing measurement was designated: fetch its results, without
    # waiting.
    measurement = Blaeu.Measurement(data=None, id=config.measurement_id)
    rdata = measurement.results(wait=False)
else:
    # Complete the description of a new "sslcert" measurement.
    definition = data["definitions"][0]
    definition["target"] = target
    definition["type"] = "sslcert"
    definition["description"] = "X.509 cert of %s" % target
    del definition["size"] # Meaningless argument

    # A literal IP address forces the address family and requires probes
    # tagged as having working connectivity for that family.
    # TODO: or use is_ip_address(str) from blaeu-reach?
    if ":" in target:
        config.ipv4 = False
        working_tag = "system-ipv6-works"
    elif re.match("^[0-9.]+$", target):
        config.ipv4 = True
        working_tag = "system-ipv4-works"
    else:
        # Hostname: the family set in the configuration is kept and no
        # tag is added.
        working_tag = None
    if working_tag is not None:
        # Work on a copy so that config.include itself is left untouched.
        required_tags = [] if config.include is None else copy.copy(config.include)
        required_tags.append(working_tag)
        data["probes"][0]["tags"]["include"] = required_tags
    af = 4 if config.ipv4 else 6
    definition['af'] = af
    if config.sni:
        definition['hostname'] = target

    if config.verbose:
        print(data)

    try:
        measurement = Blaeu.Measurement(data)
    except Blaeu.RequestSubmissionError as error:
        print(Blaeu.format_error(error), file=sys.stderr)
        sys.exit(1)
    if config.verbose:
        print("Measurement #%s to %s uses %i probes" % (measurement.id, target,
                                                        measurement.num_probes))
    rdata = measurement.results(wait=True, percentage_required=config.percentage_required)

# Group the results by the value displayed for them (the requested
# certificate field, or a failure message) and count each value.
sets = collections.defaultdict(Set)
if config.display_probe_asns:
    # Displaying the AS numbers of the probes implies displaying the probes.
    config.display_probes = True
if config.display_probes:
    # Maps each displayed value to the list of probes that reported it.
    probes_sets = collections.defaultdict(list)
print("%s probes reported" % len(rdata))
for result in rdata:
    if config.display_probes:
        probe_id = result["prb_id"]
    if config.display_probe_asns:
        # Use the probe cache when one is configured, otherwise build
        # the probe object directly.
        details = Blaeu.ProbeCache.cache_probe_id(config.cache_probes, probe_id) \
            if config.cache_probes else Blaeu.Probe(probe_id)
        # The attribute read depends on the address family; None if absent.
        asn = getattr(details, "asn_v%i" % (4 if config.ipv4 else 6), None)
    if 'cert' in result:
        # TODO: handle chains of certificates
        x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, str(result['cert'][0]))
        detail = ""
        if config.display == "i":
            content = format_name(x509.get_issuer())
        elif config.display == "k":
            key = x509.get_pubkey()
            serialization = cryptography.hazmat.primitives.serialization
            pem = key.to_cryptography_key().public_bytes(
                serialization.Encoding.PEM,
                serialization.PublicFormat.SubjectPublicKeyInfo).decode()
            # Put the key on one line, without the PEM header, and
            # truncate it to 80 characters.
            one_line = pem.replace("\n", " ").replace("-----BEGIN PUBLIC KEY----- ", "")
            content = "%s, type %s, %s bits" % (one_line[:80] + "...",
                                                key.type(), key.bits())
        elif config.display == "s":
            content = format(x509.get_serial_number(), '05x')
        elif config.display == "e":
            if x509.has_expired():
                detail = " (EXPIRED)"
            # TODO: better format of the date?
            # The detail is appended once, when the value is built below.
            content = x509.get_notAfter().decode()
        else:
            # Default: the subject name.
            content = format_name(x509.get_subject())
        value = "%s%s" % (content, detail)
    else:
        # No certificate: report the error or the alert, if there is one.
        if 'err' in result:
            error = result['err']
        elif 'alert' in result:
            error = result['alert']
        else:
            error = "UNKNOWN ERROR"
        value = "FAILED TO GET A CERT: %s" % error
    sets[value].total += 1
    if config.display_probes:
        info = [probe_id, asn] if config.display_probe_asns else probe_id
        probes_sets[value].append(info)

# Report the values from the least to the most frequent.
sets_data = sorted(sets, key=lambda s: sets[s].total)
for myset in sets_data:
    detail = ""
    if config.display_probes:
        detail = "(probes %s)" % probes_sets[myset]
    print("[%s] : %i occurrences %s" % (myset, sets[myset].total, detail))

print("Test #%s done at %s" % (measurement.id,
                               time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())))
