#!/usr/bin/python3
#
# oFono - Open Source Telephony
# Copyright (C) 2023  Cruise, LLC
#
# SPDX-License-Identifier: GPL-2.0-only
import xml.etree.ElementTree as ET
import sys
import json
import bisect
from argparse import ArgumentParser, FileType
from pathlib import Path
import random
import struct
import ctypes

class ProviderInfo:
    """One provider record: a name, an optional SPN, ids and APN contexts.

    ``mccmnc_list`` holds the concatenated MCC+MNC digit strings in sorted
    order and ``context_list`` holds one dictionary per APN context.
    """

    # Position of every permitted context key; add_context() uses it both to
    # order the keys of a context and to reject keys that are not known.
    sort_order_map = { v : pos for pos, v in
                      enumerate( ['name',
                                  'apn',
                                  'type',
                                  'protocol',
                                  'mmsc',
                                  'mmsproxy',
                                  'authentication',
                                  'username',
                                  'password',
                                  'tags'] ) }

    @classmethod
    def rawimport(cls, entry):
        """
        Build a ProviderInfo from one decoded JSON entry.

        Parameters:
        - entry (dict): Must contain 'name'; may contain 'ids', 'spn' and
          'apns'.

        Returns:
        - ProviderInfo: The populated and validated object.

        Raises:
        - SystemExit: If the name is missing, or an id, the spn, an apn or
          the entry as a whole is invalid.
        """
        if 'name' not in entry:
            raise SystemExit('No name for entry: ' + str(entry))

        info = cls(entry['name'])

        for networkid in entry.get('ids', []):
            if not info.add_id(networkid):
                raise SystemExit('Invalid network id: ' + str(networkid))

        if 'spn' in entry:
            if not info.set_spn(entry['spn']):
                # Report the offending value taken from the entry itself
                raise SystemExit('Invalid spn: ' + str(entry['spn']))

        for apn in entry.get('apns', []):
            if not info.add_context(apn):
                raise SystemExit('Invalid apn: ' + str(apn))

        if not info.is_valid():
            raise SystemExit('Invalid entry: ' + str(entry))

        return info

    def __init__(self, name):
        self.context_list = []
        self.mccmnc_list = []
        self.name = name
        self.spn = None

    @staticmethod
    def is_valid_id(id_string, expected_lengths):
        """
        Check if the identifier string is valid.

        Parameters:
        - id_string: The id string to check.
        - expected_lengths (tuple): The lengths that are acceptable.

        Returns:
        - bool: True if the id is a string of ASCII digits with one of the
          expected lengths and a non-zero numeric value, False otherwise.
        """
        if not isinstance(id_string, str):
            return False

        # str.isdigit() alone also accepts characters such as superscript
        # digits, which int() cannot convert
        if not (id_string.isascii() and id_string.isdigit()):
            return False

        if len(id_string) not in expected_lengths:
            return False

        return int(id_string) != 0

    def add_mccmnc(self, mcc, mnc):
        """Add an id given as separate MCC (3 digits) and MNC (2-3 digits).

        Returns True if the id was added, False if either part is invalid.
        """
        if not self.is_valid_id(mcc, (3,)) or not self.is_valid_id(mnc, (2, 3)):
            return False

        bisect.insort(self.mccmnc_list, mcc + mnc)
        return True

    def add_id(self, mccmnc):
        """Add an id given as a single 5 or 6 digit MCC+MNC string.

        Returns True if the id was added, False if it is invalid.
        """
        if not self.is_valid_id(mccmnc, (5, 6)):
            return False

        # Neither the MCC part nor the MNC part may be all zeros
        if int(mccmnc[:3]) == 0 or int(mccmnc[3:]) == 0:
            return False

        bisect.insort(self.mccmnc_list, mccmnc)
        return True

    def set_spn(self, spn):
        """Set the SPN; returns False unless it is a 1..254 character string."""
        if not isinstance(spn, str):
            return False

        if len(spn) == 0 or len(spn) > 254:
            return False

        self.spn = spn
        return True

    def add_context(self, info):
        """Append an APN context with its keys ordered per sort_order_map.

        Returns False, leaving the object untouched, if ``info`` is not a
        dictionary or contains a key missing from sort_order_map.
        """
        if not isinstance(info, dict):
            return False

        if any(key not in self.sort_order_map for key in info):
            return False

        ordered = dict(sorted(info.items(),
                              key = lambda pair: self.sort_order_map[pair[0]]))
        self.context_list.append(ordered)

        return True

    def is_valid(self):
        """Return True if at least one context and at least one id exist."""
        return bool(self.context_list) and bool(self.mccmnc_list)

    def __str__(self):
        s = 'Provider \'' + self.name + '\''

        if self.spn is not None:
            s += ' [SPN:\'' + self.spn + '\']'

        s += ' ' + str(self.mccmnc_list) + '\n'

        for context in self.context_list:
            s += '\t' + str(context) + '\n'

        return s

class MobileBroadbandProviderInfo:
    """Reader for a mobile-broadband-provider-info style XML database.

    parse() turns every <provider> element that has a name, at least one
    valid GSM network id and at least one usable APN into a ProviderInfo.
    """

    # Maps the 'type' attribute of an APN's <usage> element to the list of
    # context types stored in the resulting context dictionary
    usage_to_type = { 'internet' : ['internet'],
                        'mms' : ['mms'],
                        'wap' : ['wap'],
                        'mms-internet-hipri' : ['internet', 'mms'],
                        'mms-internet-hipri-fota' : ['internet','mms'],
                     }

    @classmethod
    def type_from_usage(cls, usage):
        """Return the context types for a usage string, or None if unknown."""
        return cls.usage_to_type.get(usage)

    def __init__(self, xml_path):
        self.tree = ET.parse(xml_path)

    def _parse_apn(self, apn):
        """Convert one <apn> element into a context dict, or None to skip it."""
        value = apn.get('value')
        if value is None:
            return None

        # Usage is missing for some APNs, skip such contexts for now
        usage = apn.find('usage')
        if usage is None or usage.get('type') is None:
            return None

        types = self.type_from_usage(usage.get('type'))
        if types is None:
            sys.stderr.write("Unable to convert type: %s\n" %
                                usage.get('type'))
            return None

        # Copy so that contexts do not share the class-level list
        context = { 'apn' : value, 'type' : list(types) }

        if 'mms' in context['type']:
            mmsc = apn.find('mmsc')

            # Ignore MMS contexts with no MMSC since it is needed
            # to send messages
            if mmsc is None or not mmsc.text:
                return None

            context['mmsc'] = mmsc.text

            mmsproxy = apn.find('mmsproxy')
            if mmsproxy is not None and mmsproxy.text:
                context['mmsproxy'] = mmsproxy.text

        for tag in ('username', 'password'):
            element = apn.find(tag)
            if element is not None and element.text:
                context[tag] = element.text

        # Only record an authentication method that is actually present;
        # a None value would be written out as JSON null
        authentication = apn.find('authentication')
        if authentication is not None:
            method = authentication.get('method')
            if method is not None:
                context['authentication'] = method

        context_name = apn.find('name')
        if context_name is not None and context_name.text:
            context['name'] = context_name.text

        return context

    def parse(self, xml_path):
        """
        Parse the XML database at xml_path.

        Parameters:
        - xml_path: Path (or file object) accepted by ElementTree.parse.

        Returns:
        - list: ProviderInfo objects for all valid providers.  The list is
          empty if the XML is malformed; the error is reported on stderr.
        """
        providers = []

        try:
            root = ET.parse(xml_path).getroot()
        except ET.ParseError as e:
            # stderr, since stdout may be carrying the converted JSON
            sys.stderr.write("Error parsing XML: %s\n" % e)
            return providers

        for provider in root.findall('.//provider'):
            name = provider.find('name')
            if name is None or not name.text:
                continue

            info = ProviderInfo(name.text)

            for networkid in provider.findall('gsm/network-id'):
                mcc = networkid.get('mcc')
                mnc = networkid.get('mnc')

                # Skip ids with a missing attribute instead of crashing
                if mcc is None or mnc is None:
                    continue

                info.add_mccmnc(mcc, mnc)

            for apn in provider.findall('gsm/apn'):
                context = self._parse_apn(apn)
                if context is not None:
                    info.add_context(context)

            if info.is_valid():
                providers.append(info)

        return providers

class ProvisionContext(ctypes.LittleEndianStructure):
    """Packed little-endian record describing a single APN context.

    Numeric settings are stored directly; every textual setting is stored
    as the offset handed back by the string accumulator.
    """
    _pack_ = 1
    _fields_ = [
        ('type', ctypes.c_uint32),
        ('protocol', ctypes.c_uint32),
        ('authentication', ctypes.c_uint32),
        ('reserved', ctypes.c_uint32),
        ('name_offset', ctypes.c_uint64),
        ('apn_offset', ctypes.c_uint64),
        ('username_offset', ctypes.c_uint64),
        ('password_offset', ctypes.c_uint64),
        ('mmsproxy_offset', ctypes.c_uint64),
        ('mmsc_offset', ctypes.c_uint64),
        ('tags_offset', ctypes.c_uint64),
    ]

    authentication_dict = { 'chap' : 0, 'pap' : 1, 'none' : 2 }
    protocol_dict = { 'ipv4' : 0, 'ipv6' : 1, 'ipv4v6' : 2 }
    attrs = ['name', 'apn', 'username', 'password', 'mmsproxy', 'mmsc', 'tags']

    @classmethod
    def type_to_context_type(cls, types):
        """Fold a sequence of type names into a bitmask.

        Names that are not recognised contribute nothing to the result.
        """
        flag_table = (
            ('internet', 0x0001),
            ('mms', 0x0002),
            ('wap', 0x0004),
            ('ims', 0x0008),
            ('supl', 0x0010),
            ('ia', 0x0020),
        )

        mask = 0

        for requested in types:
            for label, flag in flag_table:
                if requested == label:
                    mask |= flag
                    break

        return mask

    def __init__(self, apn, strings):
        """Fill the record from the context dict ``apn``.

        ``strings`` must provide add_string(); it is called once for each
        name in ``attrs``, in that order, with None for absent settings.
        """
        self.type = self.type_to_context_type(apn['type'])

        # Absent protocol / authentication fall back to ipv4v6 / chap
        protocol_name = apn.get('protocol', 'ipv4v6')
        self.protocol = self.protocol_dict[protocol_name]

        auth_name = apn.get('authentication', 'chap')
        self.authentication = self.authentication_dict[auth_name]

        for attr in self.attrs:
            setattr(self, attr + '_offset', strings.add_string(apn.get(attr)))

class ProvisionData(ctypes.LittleEndianStructure):
    """Packed little-endian pair: SPN string offset and context offset."""
    _pack_ = 1
    _fields_ = [
        ('spn_offset', ctypes.c_uint64),
        ('context_offset', ctypes.c_uint64)
    ]

    def __init__(self, spn, offset, strings):
        """Register ``spn`` with ``strings`` and store both offsets."""
        # Positional arguments fill the fields in declaration order
        super().__init__(strings.add_string(spn), offset)

class ProvisionNode(ctypes.LittleEndianStructure):
    """Trie node; the structure fields describe its packed on-disk size.

    At runtime a node additionally carries its two child links (``bit``),
    its integer ``key``, the per-SPN ``entries`` dict and ``node_offset``.
    """
    _pack_ = 1
    _fields_ = [
        ('bit_offsets', ctypes.c_uint64 * 2),
        ('mccmnc', ctypes.c_uint32),
        ('diff', ctypes.c_int32),
        ('provision_data_count', ctypes.c_uint64)
    ]

    style = "bold"
    fmt_connection = '\t"%s/%d" -> "%s/%d"[color="#%06x"];\n'
    fmt_declaration = '\t"%s/%d"[style=%s, color="#%06x"];\n'
    red = 0xff0000
    green = 0x00ff00

    def __init__(self, key, diff):
        self.bit = [None, None]
        self.key = key
        self.diff = diff
        self.entries = {}
        self.node_offset = 0

    def choose(self, key):
        """Return the bit of ``key`` (0 or 1) that this node tests."""
        shift = 31 - self.diff
        return (key >> shift) & 1

    def print_graphviz(self, f):
        """Write this node, its two edges and its subtrees to ``f``.

        The bit[0] edge is drawn in red and the bit[1] edge in green; the
        node itself gets a random colour.
        """
        label = format(self.key, '032b')

        f.write(self.fmt_declaration % (label, self.diff, self.style,
                                        random.randint(0, 0x00ffffff)))

        for child, color in zip(self.bit, (self.red, self.green)):
            f.write(self.fmt_connection % (label, self.diff,
                                           format(child.key, '032b'),
                                           child.diff, color))

        # Only follow links whose target has a larger diff than this node
        for child in self.bit:
            if self.diff < child.diff:
                child.print_graphviz(f)

    def __str__(self):
        return '{}/{}'.format(format(self.key, '032b'), self.diff)

class MccMncTree:
    """Binary trie over 32-bit integer keys built from ProvisionNode objects.

    Every node records the index (counted from the most significant bit)
    of the bit it tests.  A link whose target does not have a larger index
    than its source ends the descent.  The root is a sentinel with key 0
    and index -1 whose links initially point back at itself.
    """

    @staticmethod
    def clz(v):
        """Return the number of leading zero bits of ``v`` viewed as 32 bits."""
        count = 32
        while count and v:
            v = v >> 1
            count = count - 1

        return count

    @staticmethod
    def diff(key1, key2):
        """Return the index of the first differing bit (32 if keys are equal)."""
        return MccMncTree.clz(key1 ^ key2)

    def __init__(self):
        self.root = ProvisionNode(key = 0, diff = -1)
        self.root.bit[0] = self.root
        self.root.bit[1] = self.root
        self.n_nodes = 1

    def print_graphviz(self):
        """Dump the trie to 'step<n_nodes>.dot' in Graphviz format."""
        # Use 'dot -Tx11' to visualize
        # The with block closes the file even if writing a node fails
        with open("step%d.dot" % self.n_nodes, "w") as f:
            f.write('digraph trie {\n')
            self.root.print_graphviz(f)
            f.write('}\n')

    def find_closest(self, key):
        """Descend along the bits of ``key`` and return the node reached.

        The returned node does not necessarily hold ``key``; callers must
        compare keys themselves.
        """
        parent = self.root
        child = self.root.bit[0]

        while parent.diff < child.diff:
            parent = child
            child = child.bit[child.choose(key)]

        return child

    def find(self, key):
        """Return the node holding exactly ``key``, or None."""
        found = self.find_closest(key)
        if found.key == key:
            return found

        return None

    def insert(self, key, attr, value):
        """Store ``value`` under ``attr`` in the node for ``key``.

        A new node is created if ``key`` is not present yet; otherwise the
        existing node's entry for ``attr`` is set (replacing any old one).
        """
        node = self.find_closest(key)
        if node.key == key:
            node.entries[attr] = value
            return

        # First bit at which the new key departs from its closest match
        bit = self.diff(node.key, key)
        parent = self.root
        child = self.root.bit[0]

        # Descend again, stopping at the link the new node must split
        while (parent.diff < child.diff) and (child.diff < bit):
            parent = child
            child = child.bit[child.choose(key)]

        node = ProvisionNode(key, bit)
        side = node.choose(key)
        node.bit[side] = node
        node.bit[1 - side] = child

        node.entries[attr] = value

        if parent is self.root:
            self.root.bit[0] = node
        else:
            parent.bit[parent.choose(key)] = node

        self.n_nodes += 1

    def traverse_recursive(self, node, bit, visitor):
        """Visit nodes below ``node``; ``bit`` is the index of the link's source."""
        if node is self.root:
            return

        # A link that does not lead to a larger index is where the node
        # itself is reported
        if node.diff <= bit:
            visitor.visit(node)
            return

        self.traverse_recursive(node.bit[0], node.diff, visitor)
        self.traverse_recursive(node.bit[1], node.diff, visitor)

    def traverse(self, visitor):
        """Call visitor.visit(node) for every node except the root."""
        self.traverse_recursive(self.root.bit[0], -1, visitor)

class StringAccumulator:
    """Deduplicating pool of NUL-terminated UTF-8 strings."""

    def __init__(self):
        self.offsets = {}
        # One leading NUL byte, so offset 0 is free to stand for "no string"
        self.data = bytearray(1)

    def add_string(self, s):
        """Return the offset of ``s`` in the pool, adding it when it is new.

        None is never stored and always maps to offset 0.
        """
        if s is None:
            return 0

        known = self.offsets.get(s)
        if known is not None:
            return known

        encoded = s.encode('utf-8') + b'\x00'
        position = len(self.data)
        self.data.extend(encoded)
        self.offsets[s] = position

        return position

    def get_bytes(self):
        """Return the underlying bytearray holding the whole pool."""
        return self.data

class ProvisionDatabase(ctypes.LittleEndianStructure):
    """File header of the binary provisioning database, plus its builder.

    The structure fields are the packed little-endian header.  serialize()
    emits the header followed by the node section, the contexts section
    and the strings section, in that order.
    """
    _pack_ = 1
    _fields_ = [
        ('version', ctypes.c_uint64),
        ('file_size', ctypes.c_uint64),
        ('header_size', ctypes.c_uint64),
        ('node_struct_size', ctypes.c_uint64),
        ('provision_data_struct_size', ctypes.c_uint64),
        ('context_struct_size', ctypes.c_uint64),
        ('nodes_offset', ctypes.c_uint64),
        ('nodes_size', ctypes.c_uint64),
        ('contexts_offset', ctypes.c_uint64),
        ('contexts_size', ctypes.c_uint64),
        ('strings_offset', ctypes.c_uint64),
        ('strings_size', ctypes.c_uint64)
    ]

    class CalculateNodeOffsetVisitor:
        """Assigns node_offset to each visited node, in visiting order.

        Offsets start at 0, i.e. they are relative to the first node
        visited; current_offset ends up as the total size of all nodes.
        """
        def __init__(self):
            self.current_offset = 0
        def visit(self, node):
            node.node_offset = self.current_offset

            # Node data is followed by at least one ProvisionData object, with
            # the only exception being root, which has no data by definition
            self.current_offset += ctypes.sizeof(ProvisionNode)
            self.current_offset += (ctypes.sizeof(ProvisionData) *
                                    len(node.entries))

    class SerializeVisitor:
        """Appends the packed form of each visited node to ``buffer``."""
        def __init__(self, buffer):
            self.buffer = buffer

        def visit(self, node):
            # Node doesn't quite fit the C structure definition, so do this
            # manually by using struct.pack
            # Packed as: both child node offsets, the key, the diff and the
            # number of ProvisionData records that follow
            self.buffer.extend(struct.pack('<QQIiQ',
                                           node.bit[0].node_offset,
                                           node.bit[1].node_offset,
                                           node.key, node.diff,
                                           len(node.entries)))

            # Records are written in sorted order of their SPN key
            for spn in sorted(node.entries):
                pd = node.entries[spn]
                self.buffer.extend(bytes(pd))

    def __init__(self, provider_infos):
        """Build the sections and fill in the header fields.

        Parameters:
        - provider_infos: Iterable of objects exposing spn, context_list
          and mccmnc_list (ProviderInfo instances in this file).
        """
        self.strings = StringAccumulator()
        self.contexts = bytearray()
        self.tree = MccMncTree()

        for info in provider_infos:
            # context_offset is the current length of the contexts blob, so
            # it is relative to the start of the contexts section
            pd = ProvisionData(info.spn, len(self.contexts), self.strings)

            # Each provider contributes a 64-bit context count followed by
            # that many ProvisionContext structures
            self.contexts.extend(struct.pack('<Q', len(info.context_list)))

            for context in info.context_list:
                self.contexts.extend(bytes(ProvisionContext(context,
                                                            self.strings)))

            for mccmnc in info.mccmnc_list:
                # Sort None spns as '' so they're first in the list when
                # the SerializeVisitor sorts the entries dict
                spn = info.spn if info.spn is not None else ''

                # 2 and 3 digit MNCs are treated differently, even if they
                # evaluate to the same integer.  For example, 02 and 002 are
                # different MNCs.  In practice this doesn't actually happen
                # except on test networks, but account for this possibility
                # by using the upper 10 bits for the MCC, followed by a
                # single bit which signifies whether a 3 digit MNC is used,
                # followed by 10 bits of the MNC
                key = int(mccmnc[:3]) << 11 | int(mccmnc[3:])
                if len(mccmnc[3:]) == 3:
                    key |= 1 << 10

                # NOTE(review): a later provider with the same id and SPN
                # replaces the earlier record, since insert() assigns into
                # the node's entries dict -- confirm this is intended
                self.tree.insert(key, spn, pd)

        # The root is visited explicitly because traverse() skips it; this
        # must match the visiting order used in serialize()
        visitor = self.CalculateNodeOffsetVisitor()
        visitor.visit(self.tree.root)
        self.tree.traverse(visitor)

        self.version = 2
        self.header_size = ctypes.sizeof(ProvisionDatabase)
        self.file_size = self.header_size
        self.node_struct_size = ctypes.sizeof(ProvisionNode)
        self.provision_data_struct_size = ctypes.sizeof(ProvisionData)
        self.context_struct_size = ctypes.sizeof(ProvisionContext)
        self.nodes_offset = self.header_size
        self.nodes_size = visitor.current_offset
        self.file_size += self.nodes_size
        self.contexts_offset = self.nodes_offset + self.nodes_size
        self.contexts_size = len(self.contexts)
        self.file_size += self.contexts_size
        self.strings_offset = self.contexts_offset + self.contexts_size
        self.strings_size = len(self.strings.get_bytes())
        self.file_size += self.strings_size

    def serialize(self):
        """Return the complete database image as a bytearray."""
        buffer = bytearray()
        buffer.extend(bytes(self))

        # Same order as the offset calculation: root first, then the rest
        visitor = self.SerializeVisitor(buffer)
        visitor.visit(self.tree.root)
        self.tree.traverse(visitor)

        buffer.extend(self.contexts)
        buffer.extend(self.strings.get_bytes())

        return buffer

class ProviderInfoJSONEncoder(json.JSONEncoder):
    """JSON encoder that knows how to write ProviderInfo objects."""

    def default(self, obj):
        """Return a dict for ProviderInfo; defer to the base class otherwise.

        The 'spn' key is emitted only when an SPN is set.
        """
        if not isinstance(obj, ProviderInfo):
            return super().default(obj)

        encoded = { 'name' : obj.name, 'ids' : obj.mccmnc_list }

        if obj.spn is not None:
            encoded['spn'] = obj.spn

        encoded['apns'] = obj.context_list
        return encoded

def mbpi_convert(args):
    """Convert the system mobile-broadband-provider-info XML into JSON.

    Parameters:
    - args: Parsed arguments; args.outfile is a Path, or None to write the
      JSON to standard output.

    Raises:
    - SystemExit: If the JSON cannot be produced or the output file cannot
      be opened or written.
    """
    xml_path = '/usr/share/mobile-broadband-provider-info/serviceproviders.xml'
    mbpi = MobileBroadbandProviderInfo(xml_path)
    provider_infos = mbpi.parse(xml_path)

    try:
        if args.outfile is None:
            # Write directly: wrapping sys.stdout in a with block would
            # close it for the rest of the process
            json.dump(provider_infos, sys.stdout, ensure_ascii=False,
                      indent=2, cls=ProviderInfoJSONEncoder)
        else:
            with args.outfile.open('w', encoding='utf-8') as outfile:
                json.dump(provider_infos, outfile, ensure_ascii=False,
                          indent=2, cls=ProviderInfoJSONEncoder)
    except (ValueError, OSError) as e:
        raise SystemExit(e)

def generate(args):
    """Read the JSON database from args.infile and write the binary form.

    Every entry is imported through ProviderInfo.rawimport; the serialized
    ProvisionDatabase is written to args.outfile.  JSON that cannot be
    decoded ends the program via SystemExit.
    """
    try:
        entries = json.load(args.infile)
    except ValueError as err:
        raise SystemExit(err)

    provider_infos = [ProviderInfo.rawimport(entry) for entry in entries]
    database = ProvisionDatabase(provider_infos)

    with args.outfile.open('wb') as outfile:
        outfile.write(database.serialize())

def selftest(args):
    """Exercise the trie and the database builder, then write 'sample.db'.

    First checks that random keys inserted into an MccMncTree can all be
    found again and are traversed in ascending order, then builds a
    database from a built-in JSON sample and checks its string section.
    """
    trie = MccMncTree()

    # Insert this many random keys.  Some of them collide, so the number of
    # nodes actually created is read back from the tree; the lookup pass
    # below then has to find exactly that many keys.
    insert_attempts = 100000
    for _ in range(insert_attempts):
        trie.insert(random.randint(10000, 999999), None, None)

    n_inserted = trie.n_nodes
    print("Created a tree with %d nodes (1 root)" % n_inserted)

    # Probing in ascending order yields the present keys already sorted
    expected_keys = [candidate for candidate in range(10000, 1000000)
                     if trie.find(candidate)]
    n_found = len(expected_keys)

    print("Found %d nodes (not including root)" % n_found)
    assert n_found == n_inserted - 1

    class KeyCollector:
        """Visitor that records the key of every node it is shown."""
        def __init__(self):
            self.keys = []

        def visit(self, node):
            self.keys.append(node.key)

    collector = KeyCollector()
    trie.traverse(collector)
    assert collector.keys == expected_keys

    sample_json = """[
    {
      "name": "Operator XYZ",
      "ids": [
        "99955", "99956", "99901", "99902"
      ],
      "apns": [
        {
          "name": "Internet",
          "apn": "internet",
          "type": [
            "internet"
          ],
          "authentication": "none",
          "protocol": "ipv4"
        },
        {
          "name": "IMS+MMS",
          "apn": "imsmms",
          "type": [
            "ims", "mms"
          ],
          "mmsc": "foobar.mmsc:80",
          "mmsproxy": "mms.proxy.net",
          "authentication": "pap",
          "protocol": "ipv6"
        }
      ]
    },
    {
      "name": "Operator ZYX",
      "ids": [
        "99998", "99999", "99901", "99902"
      ],
      "spn": "ZYX",
      "apns": [
        {
          "name": "ZYX",
          "apn": "zyx",
          "type": [
            "internet"
          ],
          "authentication": "none",
          "protocol": "ipv4"
        }
      ]
    }
    ]"""
    try:
        sample_entries = json.loads(sample_json)
    except ValueError as err:
        raise SystemExit(err)

    provider_infos = [ProviderInfo.rawimport(entry)
                      for entry in sample_entries]

    database = ProvisionDatabase(provider_infos)

    # Strings appear once each, in the order they were first registered
    expected_strings = bytearray(b'\x00Internet\x00internet\x00IMS+MMS\x00' +
                b'imsmms\x00mms.proxy.net\x00foobar.mmsc:80\x00ZYX\x00zyx\x00')
    assert database.strings.get_bytes() == expected_strings

    with open('sample.db', 'wb') as outfile:
        outfile.write(database.serialize())

if __name__ == "__main__":
    # Command line: one sub-command per operation, each of which registers
    # its handler function as the 'func' default
    cli = ArgumentParser(description='Parse command line arguments')
    commands = cli.add_subparsers(title='subcommands', dest='command')

    # mbpi-convert command
    convert_cmd = commands.add_parser(
        'mbpi-convert',
        help='Convert mobile-broadband-provider-info database')
    convert_cmd.add_argument('--outfile', type=Path, default=None,
                             required=False, help='Output file path')
    convert_cmd.set_defaults(func=mbpi_convert)

    # generate command
    generate_cmd = commands.add_parser('generate',
                                       help='Generate binary provider db')
    generate_cmd.add_argument('--outfile', type=Path, default='provision.db',
                              required=False, help='Output file path')
    generate_cmd.add_argument('--infile', type=FileType(encoding='utf-8'),
                              default=sys.stdin, help='Input JSON db')
    generate_cmd.set_defaults(func=generate)

    # selftest command
    selftest_cmd = commands.add_parser('selftest', help='Run self-tests')
    selftest_cmd.set_defaults(func=selftest)

    options = cli.parse_args()

    # Without a sub-command no handler was registered; show the help text
    handler = getattr(options, 'func', None)
    if handler is None:
        cli.print_help()
    else:
        handler(options)
