#!/usr/bin/python3

#
# Copyright (C) 2025 Nethesis S.r.l.
# SPDX-License-Identifier: GPL-2.0-only
#

from euci import EUci
from requests import HTTPError
from requests.auth import HTTPBasicAuth
import hashlib
import ipaddress
import requests
import nethsec.utils
import nethsec.objects
import json
import subprocess
import logging
import random
import signal
from threading import Event
import resource

# Base URL every Flashstart REST call in this module is built from.
FLASHSTART_API_ENDPOINT = 'https://api.flashstart.com/1.0.0'

# Headers sent with every API request: JSON in, JSON out.
DEFAULT_HEADERS = {
    'Content-Type': 'application/json',
    'Accept': 'application/json',
}

# Port given to the first generated dnsmasq instance; further instances count up from here.
CONST_PORT_START = 5300

# Shared UCI handle used by every function below.
# WARNING: per the original author, EUci reads the UCI configuration only once, when it is
# instantiated, so this object does not see later changes until it is re-created.
e_uci = EUci()

def __expand_ip_range(entry: str) -> list:
    """Expand a dash-separated IP range into the individual addresses it covers.

    Entries without a dash (single IPs, CIDR networks, anything else) are
    returned unchanged in a one-element list.

    For ``'<start>-<end>'`` entries:

    * both bounds valid and of the same IP version, ``start <= end``: every
      address from start to end (inclusive) is returned as a string, in the
      same address family as the bounds;
    * a bound that is not a valid IP address: a warning is logged and the
      entry is returned as-is in a one-element list;
    * bounds of different IP versions, or ``end`` lower than ``start``: the
      range describes no addresses, a warning is logged and an empty list is
      returned.

    NOTE(review): no upper limit is placed on the range size, so a very wide
    range produces a proportionally large list.

    :param entry: a single IP, a CIDR network or a dash-separated range
    :return: list of address/network strings
    """
    logging.debug(f'Expanding IP entry {entry!r}')
    if '-' not in entry:
        return [entry]

    first, last = (part.strip() for part in entry.split('-', 1))
    try:
        start = ipaddress.ip_address(first)
        end = ipaddress.ip_address(last)
    except ValueError:
        logging.warning(f'Invalid IP range entry: {entry!r}, keeping as-is')
        return [entry]

    # Comparing the integer values of an IPv4 and an IPv6 address is meaningless and
    # could span an astronomically large range, so refuse mixed families outright.
    if start.version != end.version or end < start:
        logging.warning(f'IP range entry {entry!r} does not describe any address, skipping')
        return []

    # Rebuild each address with the class of the bounds: ipaddress.ip_address(int)
    # would silently turn IPv6 values below 2**32 into IPv4 addresses.
    address_cls = type(start)
    return [str(address_cls(value)) for value in range(int(start), int(end) + 1)]


def __load_bypass_rules() -> list:
    """Parse the ``flashstart.global.bypass`` list into ``ipaddress`` objects.

    Empty entries are skipped, every remaining entry goes through
    ``__expand_ip_range`` and each resulting string is parsed as a network
    (non-strict) or, failing that, as a single address. Strings that parse as
    neither are logged and dropped.

    :return: list of ``ipaddress`` network/address objects
    """

    def parse(text):
        # network first (host bits allowed), plain address as a fallback
        try:
            return ipaddress.ip_network(text, strict=False)
        except ValueError:
            pass
        try:
            return ipaddress.ip_address(text)
        except ValueError:
            logging.warning(f'Ignoring invalid bypass entry: {text!r}')
            return None

    parsed = []
    configured = e_uci.get('flashstart', 'global', 'bypass', list=True, dtype=str, default=[])
    for raw in configured:
        if raw == '':
            continue
        for candidate in __expand_ip_range(raw):
            rule = parse(candidate)
            if rule is not None:
                parsed.append(rule)
    return parsed


def __is_bypassed_entry(entry: str, bypass_rules: list) -> bool:
    """Tell whether ``entry`` is covered by one of the bypass rules.

    :param entry: candidate IP address as a string; anything that is not a
        valid single address is never considered bypassed
    :param bypass_rules: ``ipaddress`` network and/or address objects
    :return: True when the address belongs to a network rule or equals an
        address rule, False otherwise
    """
    try:
        candidate = ipaddress.ip_address(entry)
    except ValueError:
        return False

    network_types = (ipaddress.IPv4Network, ipaddress.IPv6Network)
    address_types = (ipaddress.IPv4Address, ipaddress.IPv6Address)
    return any(
        (isinstance(rule, network_types) and candidate in rule)
        or (isinstance(rule, address_types) and candidate == rule)
        for rule in bypass_rules
    )


def __refresh_uci():
    """Replace the module-level ``e_uci`` handle with a newly created ``EUci`` instance."""
    global e_uci
    e_uci = EUci()


def __get_client() -> requests.Session:
    """Build a ``requests`` session set up for the Flashstart API.

    The session carries ``DEFAULT_HEADERS`` and HTTP basic auth made of the
    ``flashstart.global.username`` value and the MD5 hex digest of the
    ``flashstart.global.password`` value (both default to an empty string).

    :return: the configured session
    """
    account = e_uci.get('flashstart', 'global', 'username', default='')
    secret = e_uci.get('flashstart', 'global', 'password', default='')

    client = requests.Session()
    client.headers.update(DEFAULT_HEADERS)
    client.auth = HTTPBasicAuth(account, hashlib.md5(secret.encode()).hexdigest())
    return client


def check_credentials(username, password, timeout=30):
    """Ask the Flashstart API whether the given credentials are valid.

    :param username: Flashstart account name
    :param password: value sent as the basic-auth password, used as-is (the
        command line entry point passes the MD5 hex digest of the password)
    :param timeout: seconds to wait for the server before giving up; without
        it ``requests`` would wait indefinitely on an unresponsive server
    :raises requests.HTTPError: when the server answers with an error status
        (401 is what the command line reports as invalid credentials)
    :raises requests.exceptions.RequestException: on connection problems or
        when ``timeout`` expires
    """
    response = requests.post(
        f'{FLASHSTART_API_ENDPOINT}/auth/check/',
        auth=HTTPBasicAuth(username, password),
        data='{"who": "Device"}',
        headers=DEFAULT_HEADERS,
        timeout=timeout,
    )
    response.raise_for_status()


def disable():
    """Delete every dhcp and firewall section flagged with ``ns_flashstart`` and save.

    Whether dhcp/firewall already had pending changes is recorded before
    touching anything and handed to ``__save``, which uses it to decide if
    the result may be committed.
    """
    had_pending = __check_pending_changes('dhcp') or __check_pending_changes('firewall')

    for package in ['dhcp', 'firewall']:
        for section in e_uci.get(package, list=True, dtype=str, default=[]):
            managed = e_uci.get(package, section, 'ns_flashstart', default=False, dtype=bool)
            if not managed:
                continue
            # section was generated by this tool: drop it
            logging.debug(f'Deleting {package} entry {section}')
            e_uci.delete(package, section)

    __save(had_pending)


def __fetch_instanced_services(config, kind):
    """List the sections of ``config`` of type ``kind`` that are flagged ``ns_flashstart``.

    :param config: UCI configuration name (e.g. ``'firewall'``)
    :param kind: UCI section type to keep (e.g. ``'ipset'``, ``'redirect'``)
    :return: list of matching section names
    """
    return [
        section
        for section in e_uci.get(config, list=True, dtype=str, default=[])
        if e_uci.get(config, section) == kind
        and e_uci.get(config, section, 'ns_flashstart', default=False, dtype=bool)
    ]


def __fetch_local_dhcp_instances():
    """List the dhcp sections that carry an ``ns_flashstart`` option.

    Only the presence of the option is checked: neither its value nor the
    section type is looked at.

    :return: list of section names
    """
    return [
        section
        for section in e_uci.get('dhcp', list=True, dtype=str, default=[])
        if e_uci.get('dhcp', section, 'ns_flashstart', default=None) is not None
    ]


def __fetch_local_redirect_instances():
    """List the firewall ``redirect`` sections flagged with ``ns_flashstart``.

    :return: list of section names
    """
    # same filter as the generic helper, pinned to firewall redirects
    return __fetch_instanced_services('firewall', 'redirect')


def __check_pending_changes(config: str):
    """Report whether ``ubus call uci changes`` lists any change for ``config``.

    :param config: UCI configuration name
    :return: True when the reply has a non-empty ``changes`` member
    :raises subprocess.CalledProcessError: when the ``ubus`` call fails
    """
    command = ['ubus', 'call', 'uci', 'changes', json.dumps({'config': config})]
    completed = subprocess.run(command, check=True, capture_output=True, text=True)
    payload = json.loads(completed.stdout)
    if 'changes' not in payload:
        return False
    return len(payload['changes']) > 0


def __sync_host_sets():
    """Upload every ``host`` object of the ``objects`` configuration to Flashstart.

    Each object becomes a record whose ``name`` is the section id,
    ``username`` is the object's ``name`` option and ``surname`` is the
    ``;``-joined list of its IPs. The records are posted as JSON to the
    ``sync/objects`` endpoint.

    :raises requests.HTTPError: when the server answers with an error status
    """
    payload = []
    for section in e_uci.get('objects', list=True, dtype=str, default=[]):
        if e_uci.get('objects', section) == 'host':
            joined_ips = ';'.join(nethsec.objects.get_object_ips(e_uci, f'objects/{section}'))
            logging.debug(f'Syncing host set {section} with entries {joined_ips}')
            record = {
                'name': section,
                'username': e_uci.get('objects', section, 'name', default=''),
                'surname': joined_ips,
            }
            payload.append(record)

    client = __get_client()
    account = e_uci.get('flashstart', 'global', 'username', default='')
    endpoint = f'{FLASHSTART_API_ENDPOINT}/sync/objects/{account}/1'
    reply = client.post(endpoint, json.dumps(payload))
    reply.raise_for_status()
    logging.debug(f'Successfully synced {len(payload)} host sets to flashstart.')


def __save(pending_changes: bool):
    """Save dhcp/firewall and commit them unless foreign changes were pending.

    :param pending_changes: True when dhcp or firewall already had uncommitted
        changes before this run; in that case nothing is committed, so that
        changes made by someone else are not applied on their behalf
    :raises subprocess.CalledProcessError: when a ``ubus`` call fails
    """
    e_uci.save('dhcp')
    e_uci.save('firewall')

    dirty = {package: __check_pending_changes(package) for package in ('dhcp', 'firewall')}
    if not (dirty['dhcp'] or dirty['firewall']):
        return

    if pending_changes:
        # leave the commit to whoever owns the other pending changes
        logging.warning('Some changes are pending to dhcp/firewall configuration, skipping commit.')
        return

    # committing through ubus triggers all the reloads
    logging.info('Applying changes to dhcp/firewall configuration.')
    if dirty['dhcp']:
        subprocess.run(['ubus', 'call', 'uci', 'commit', json.dumps({'config': 'dhcp'})], check=True)
    if dirty['firewall']:
        subprocess.run(['ubus', 'call', 'uci', 'commit', json.dumps({'config': 'firewall'})], check=True)
        subprocess.run(['fw4', 'reload-sets'], capture_output=True)


def __add_bypass_ipset():
    """Create or refresh the ``ns_flashstart_bypass`` firewall ipset.

    The section is created when missing, its fixed options are always
    rewritten and its ``entry`` list is replaced with the non-empty values of
    ``flashstart.global.bypass`` only when the sorted lists differ.
    """
    section = 'ns_flashstart_bypass'
    if e_uci.get('firewall', section, default=None) is None:
        logging.debug('Creating bypass set for flashstart')
        e_uci.set('firewall', section, 'ipset')
        e_uci.set('firewall', section, 'ns_flashstart', True)
        e_uci.set('firewall', section, 'ns_tag', ['automated'])

    fixed_options = (
        ('name', 'flashstart-bypass'),
        ('enabled', 1),
        ('match', 'net'),
        ('family', 'inet'),
    )
    for option, value in fixed_options:
        e_uci.set('firewall', section, option, value)

    # bypass list as configured for flashstart, blanks removed
    wanted = sorted(
        item
        for item in e_uci.get('flashstart', 'global', 'bypass', list=True, dtype=str, default=[])
        if item != ''
    )
    # only touch the ipset entries when they actually changed
    current = sorted(e_uci.get('firewall', section, 'entry', list=True, dtype=str, default=[]))
    if current != wanted:
        e_uci.set('firewall', section, 'entry', wanted)


def __sync_pro_profiles():
    """Sync the "pro" setup: one dnsmasq instance plus one DNS redirect per zone.

    Only the first profile returned by Flashstart is used; it is bound to
    ``CONST_PORT_START``. Every other managed dnsmasq instance and every
    managed redirect that no longer maps to a configured zone is deleted.
    """
    # remember whether someone else left uncommitted changes
    had_pending = __check_pending_changes('dhcp') or __check_pending_changes('firewall')

    # dnsmasq instances already managed by flashstart
    managed_dnsmasq = __fetch_local_dhcp_instances()

    remote = __fetch_config()

    # multiple profiles are not supported in the pro setup: keep the first only
    primary = remote['dhcp'][0]
    __add_profile(CONST_PORT_START, primary)
    for section in managed_dnsmasq:
        if section == primary['id']:
            continue
        logging.info(f'Profile {section} not present in flashstart, deleting instance...')
        e_uci.delete('dhcp', section)

    # firewall set holding the clients that must not be intercepted
    __add_bypass_ipset()

    # catch-all DNAT rule, one per zone
    managed_redirects = __fetch_local_redirect_instances()
    wanted_redirects = []
    for zone in e_uci.get('flashstart', 'global', 'zones', default=[], list=True, dtype=str):
        redirect_id = f'ns_flashstart_{zone}'
        if e_uci.get('firewall', redirect_id, default=None) is None:
            logging.debug(f'Creating new redirect {redirect_id}')
            e_uci.set('firewall', redirect_id, 'redirect')
            e_uci.set('firewall', redirect_id, 'ns_flashstart', True)
            e_uci.set('firewall', redirect_id, 'ns_tag', ['automated'])
        redirect_options = (
            ('name', f'Flashstart-intercept-DNS-from-{zone}'),
            ('src', zone),
            ('src_dport', 53),
            ('proto', 'tcp udp'),
            ('target', 'DNAT'),
            ('ipset', '!flashstart-bypass'),
            ('dest_port', CONST_PORT_START),
        )
        for option, value in redirect_options:
            e_uci.set('firewall', redirect_id, option, value)
        wanted_redirects.append(redirect_id)

    # drop the redirects that are not backed by a configured zone anymore
    for section in managed_redirects:
        if section in wanted_redirects:
            continue
        logging.info(f'Redirect {section} not present in flashstart, deleting instance...')
        e_uci.delete('firewall', section)

    __save(had_pending)


def __add_profile(port, profile):
    """Create or refresh the dnsmasq instance backing a Flashstart profile.

    :param port: port the dnsmasq instance is configured with
    :param profile: dict with the ``id``, ``name``, ``dns_code`` and
        ``servers`` keys, as built by ``__fetch_config``
    """
    log_queries = e_uci.get('flashstart', 'global', 'logqueries', default=False, dtype=bool)
    rebind = e_uci.get('flashstart', 'global', 'rebind_protection', default=False, dtype=bool)

    section = profile['id']
    if e_uci.get('dhcp', section, default=None) is None:
        logging.info(f'New profile found {profile["name"]}, creating instance {section}.')
        e_uci.set('dhcp', section, 'dnsmasq')
        e_uci.set('dhcp', section, 'ns_flashstart', True)
        e_uci.set('dhcp', section, 'ns_tag', ['automated'])

    # options rewritten on every run
    e_uci.set('dhcp', section, 'ns_flashstart_profile', profile['name'])
    e_uci.set('dhcp', section, 'ns_flashstart_dns_code', profile['dns_code'])
    e_uci.set('dhcp', section, 'port', port)
    e_uci.set('dhcp', section, 'noresolv', True)
    e_uci.set('dhcp', section, 'logqueries', log_queries)
    e_uci.set('dhcp', section, 'rebind_protection', rebind)

    # list options are replaced as a whole, so compare before writing
    extra_servers = list(e_uci.get('flashstart', 'global', 'custom_servers', list=True, dtype=str, default=[]))
    wanted = sorted(extra_servers + profile['servers'])
    current = sorted(e_uci.get('dhcp', section, 'server', list=True, dtype=str, default=[]))
    if current != wanted:
        logging.debug(f'Replacing dns servers for {section} with {wanted}')
        e_uci.set('dhcp', section, 'server', wanted)


def __sync_pro_plus_profiles():
    """Sync the "pro+" setup: one dnsmasq instance per Flashstart profile.

    For every remote profile a dnsmasq instance is configured on its own port
    (counting up from ``CONST_PORT_START``). Every profile that is not the
    catch-all gets a firewall ipset, and every (zone, profile) pair gets a
    DNAT redirect for port 53. The ipsets are then filled with the addresses
    of the active sessions sharing the profile's ``dns_code``, minus the
    bypassed addresses. Managed sections that were not (re)generated in this
    run are deleted, and the result is handed to ``__save``.
    """
    # check if some changes are pending
    pending_changes = __check_pending_changes('dhcp') or __check_pending_changes('firewall')
    # get already existing dnsmasq instances managed by flashstart
    # (snapshots taken BEFORE anything is generated: they drive the cleanup at the end)
    dhcp_instances = __fetch_local_dhcp_instances()
    ip_set_instances = __fetch_instanced_services('firewall', 'ipset')
    redirect_instances = __fetch_instanced_services('firewall', 'redirect')
    added_redirects = []

    # fetch config
    config = __fetch_config()
    remote_ids = [profile['id'] for profile in config['dhcp']]
    starting_port = CONST_PORT_START

    # generate bypass firewall set
    __add_bypass_ipset()
    # the bypass set is always kept, so it must survive the cleanup below
    added_ip_sets = ['ns_flashstart_bypass']
    
    # sort the profiles by catch-all, so that the catch-all profile is always the last one
    dhcp_entries = sorted(config['dhcp'], key=lambda entry: entry['catch-all'])
    for profile in dhcp_entries:
        __add_profile(starting_port, profile)
        # generate IPSet for each profile
        if not profile['catch-all']:
            ip_set_id = f'ns_flashstart_ipset_{profile["id"]}'
            if e_uci.get('firewall', ip_set_id, default=None) is None:
                logging.debug(f'Creating ipset for profile {profile["id"]}')
                e_uci.set('firewall', ip_set_id, 'ipset')
                e_uci.set('firewall', ip_set_id, 'ns_flashstart', True)
                e_uci.set('firewall', ip_set_id, 'ns_tag', ['automated'])
            e_uci.set('firewall', ip_set_id, 'name', f'flashstart-ipset-{profile["id"]}')
            e_uci.set('firewall', ip_set_id, 'enabled', 1)
            e_uci.set('firewall', ip_set_id, 'match', 'net')
            e_uci.set('firewall', ip_set_id, 'family', 'inet')
            # the dns_code is what later links sessions to this ipset
            e_uci.set('firewall', ip_set_id, 'ns_flashstart_dns_code', profile['dns_code'])
            # remote event handling is really in bad shape, this value is just a reference for easier retrieval
            e_uci.set('firewall', ip_set_id, 'ns_flashstart_dns', ','.join(profile['servers']))
            added_ip_sets.append(ip_set_id)
        # generate redirect rule for each profile and zone
        for zone in e_uci.get('flashstart', 'global', 'zones', default=[], list=True, dtype=str):
            redirect_id = f'ns_flashstart_{zone}_{profile["id"]}'
            if e_uci.get('firewall', redirect_id, default=None) is None:
                logging.debug(f'Creating new redirect {redirect_id}')
                e_uci.set('firewall', redirect_id, 'redirect')
                e_uci.set('firewall', redirect_id, 'ns_flashstart', True)
                e_uci.set('firewall', redirect_id, 'ns_tag', ['automated'])
            e_uci.set('firewall', redirect_id, 'src', zone)
            e_uci.set('firewall', redirect_id, 'src_dport', 53)
            e_uci.set('firewall', redirect_id, 'dest_port', starting_port)
            e_uci.set('firewall', redirect_id, 'proto', "tcp udp")
            e_uci.set('firewall', redirect_id, 'target', 'DNAT')
            if profile['catch-all']:
                # catch-all: matches every client that is NOT in the bypass set
                e_uci.set('firewall', redirect_id, 'name', f'Flashstart-catch-all-{zone}-{profile["id"]}')
                e_uci.set('firewall', redirect_id, 'ipset', f'!flashstart-bypass')
            else:
                # regular profile: matches only the clients listed in the profile's own ipset
                e_uci.set('firewall', redirect_id, 'name', f'Flashstart-intercept-DNS-from-{zone}-{profile["id"]}')
                e_uci.set('firewall', redirect_id, 'ipset', f'flashstart-ipset-{profile["id"]}')
            added_redirects.append(redirect_id)
        # each profile gets its own dnsmasq port
        starting_port += 1

    # apply rule ordering: profile rules first, then the catch-all last
    # (the save comes first because the reorder is done by the external `uci` command)
    e_uci.save('firewall')
    index = 0
    for redirect in added_redirects:
        # best effort: the outcome of the reorder command is not checked
        subprocess.run(["uci", "reorder", f"firewall.{redirect}={index}"], capture_output=True)
        index += 1

    # dns_code -> list of addresses that must be filtered with that profile
    active_sessions = dict()
    for session in config['sessions']:
        entries = []
        if session['ip'] is None:
            # syncing an object
            # the session username is matched against the 'name' of the listed objects;
            # when several objects share the name the last one wins
            # NOTE(review): list_objects is called again for every such session
            object_id = None
            for entry in nethsec.objects.list_objects(e_uci, False):
                if entry['name'] == session['username']:
                    object_id = entry['id']
            if object_id is None:
                logging.warning(f'Object {session["username"]} not found, skipping session sync.')
                continue
            for ip in nethsec.objects.get_object_ips(e_uci, object_id):
                entries.extend(__expand_ip_range(ip))
        else:
            # syncing a user login
            entries.append(session['ip'])

        if session['dns_code'] not in active_sessions:
            active_sessions[session['dns_code']] = []
        active_sessions[session['dns_code']].extend(entries)

    bypass_rules = __load_bypass_rules()

    # fill every managed ipset (but the bypass one) with the addresses of its sessions
    for ip_set in __fetch_instanced_services('firewall', 'ipset'):
        if ip_set == 'ns_flashstart_bypass':
            continue
        ipset_dns_code = e_uci.get('firewall', ip_set, 'ns_flashstart_dns_code', default='')
        if ipset_dns_code in active_sessions:
            # de-duplicate, sort for a stable comparison, then drop bypassed addresses
            upstream_entries = sorted(list(set(active_sessions.get(ipset_dns_code))))
            upstream_entries = [entry for entry in upstream_entries if not __is_bypassed_entry(entry, bypass_rules)]
        else:
            upstream_entries = []
        # check if ipset entries are the same
        stored_entries = sorted(list(e_uci.get('firewall', ip_set, 'entry', list=True, dtype=str, default=[])))
        if stored_entries != upstream_entries:
            logging.debug(f'Updating ipset {ip_set} with entries {upstream_entries}')
            e_uci.set('firewall', ip_set, 'entry', upstream_entries)

    # remove every other profile other than the ones in the list
    for instance in dhcp_instances:
        if instance not in remote_ids:
            logging.info(f'Profile {instance} not present in flashstart, deleting instance...')
            e_uci.delete('dhcp', instance)
    for instance in ip_set_instances:
        if instance not in added_ip_sets:
            logging.info(f'IPSet {instance} not present in flashstart, deleting instance...')
            e_uci.delete('firewall', instance)
    for instance in redirect_instances:
        if instance not in added_redirects:
            logging.info(f'Redirect {instance} not present in flashstart, deleting instance...')
            e_uci.delete('firewall', instance)

    __save(pending_changes)


def sync():
    """Run one synchronization, picking the flow from ``flashstart.global.proplus``.

    With ``proplus`` enabled the host sets are uploaded first and the pro+
    profiles are synced afterwards; otherwise the plain pro flow is used.
    """
    if not e_uci.get('flashstart', 'global', 'proplus', default=False, dtype=bool):
        logging.debug('Syncing pro profiles')
        __sync_pro_profiles()
        return

    logging.debug('Syncing pro+ profiles')
    __sync_host_sets()
    __sync_pro_plus_profiles()


def __fetch_config():
    """Download profiles and sessions from Flashstart in this module's format.

    :return: dict with two keys: ``dhcp``, one dict per profile (``id``,
        ``name``, ``catch-all``, ``dns_code``, ``servers``) and ``sessions``,
        one dict per session (``username``, ``ip``, ``dns_code``, ``dns``)
    :raises requests.HTTPError: when the profiles or the sessions request
        answers with an error status
    """
    client = __get_client()
    username = e_uci.get('flashstart', 'global', 'username', default='')

    profiles_reply = client.get(f'{FLASHSTART_API_ENDPOINT}/network/profiles/{username}')
    profiles_reply.raise_for_status()
    dhcp = [
        {
            # section id derived from a hash of the profile name
            'id': nethsec.utils.get_id(hashlib.sha256(item['name'].encode()).hexdigest()[:10]),
            'name': item['name'],
            'catch-all': item['is_catch-all'],
            'dns_code': item['dns_code'],
            'servers': [
                item['dns']['ipv4']['primary_dns'],
                item['dns']['ipv4']['secondary_dns'],
            ],
        }
        for item in profiles_reply.json()
    ]

    sessions_reply = client.get(f'{FLASHSTART_API_ENDPOINT}/sessions/list/{username}')
    sessions_reply.raise_for_status()
    sessions = [
        {
            'username': item['username'],
            'ip': item['ip'],
            'dns_code': item['dns_code'],
            'dns': ','.join([item['dns']['primary_dns'], item['dns']['secondary_dns']]),
        }
        for item in sessions_reply.json()
    ]

    # the jobs reply has a different format and is not used for now, but it
    # still has to be requested to reset the pending job counter
    client.post(f'{FLASHSTART_API_ENDPOINT}/jobs/{username}')

    return {'dhcp': dhcp, 'sessions': sessions}


def ping():
    """Announce the public IP of every WAN device to the Flashstart DDNS service.

    The WAN devices are taken from ``fw4 zone wan``. One update request is
    sent through each device (``curl --interface``), in sorted device order:
    the first one uses the configured username, the following ones use
    ``<username>-<n>`` with ``n`` starting at 1.

    Username and password are percent-encoded before being placed in the
    query string, so credentials containing characters such as ``&``, ``#``,
    ``+``, ``%`` or spaces are transmitted intact instead of corrupting the
    request.

    :raises Exception: when the WAN devices cannot be listed; a failed update
        on a single WAN is only logged
    """
    from urllib.parse import quote

    res = subprocess.run(["/sbin/fw4", "zone", "wan"], capture_output=True, text=True)
    if res.returncode != 0:
        raise Exception(f'Failed to get WAN devices with error: {res.stdout}')
    # one device per line, blanks ignored, duplicates collapsed
    wan_devices = {line.strip() for line in res.stdout.splitlines() if line.strip()}

    user = e_uci.get('flashstart', 'global', 'username', default='')
    password = e_uci.get('flashstart', 'global', 'password', default='')
    encoded_password = quote(password, safe='')

    for counter, wan in enumerate(sorted(wan_devices)):
        wan_id = f"{user}-{counter}" if counter > 0 else user
        # NOTE(review): the password travels on the curl command line and is therefore
        # visible in the process list while the request is running.
        url = ("https://ddns.flashstart.com/nic/update?hostname=&myip=&wildcard=NOCHG"
               f"&username={quote(wan_id, safe='')}&password={encoded_password}")
        cmd = ["curl", "-s", "--connect-timeout", "10", "--interface", wan, url]
        res = subprocess.run(cmd, capture_output=True, text=True)
        if res.returncode != 0:
            logging.warning(f'Failed to update IP for WAN {wan} -> {wan_id} with error: {res.stdout}')


def __update_version():
    """Report this client's type and version to the ``updates/check`` endpoint.

    Per the original author this enables catch-all switching from the
    Flashstart server. The reply is not inspected.
    """
    client = __get_client()
    username = e_uci.get('flashstart', 'global', 'username', default='')
    version_info = {
        'type': '5',
        'sys_type': '11',
        'version': '1.0.2',
    }
    endpoint = f'{FLASHSTART_API_ENDPOINT}/updates/check/{username}'
    client.post(endpoint, json.dumps(version_info))


def main():
    import argparse

    parser = argparse.ArgumentParser(description='Nethesis Flashstart helper')
    parser.add_argument('--log-level', default='info', choices=['debug', 'info', 'warning', 'error'],
                        help='Set the log level (default: info)')
    parser.add_argument('--syslog', action='store_true', help='Send logs to syslog')
    subparsers = parser.add_subparsers(dest='command')

    check_credentials_parser = subparsers.add_parser('check-credentials', help='Util to check Flashstart credentials')
    check_credentials_parser.add_argument('--username', required=True)
    check_credentials_parser.add_argument('--password', required=True)

    subparsers.add_parser('sync', help='Sync Flashstart')
    subparsers.add_parser('daemon', help='Run the sync command foreground')
    subparsers.add_parser('disable', help='Remove generated configuration')
    subparsers.add_parser('ping', help='Ping flashstart server with all WANs')

    args = parser.parse_args()
    logger = logging.getLogger()
    logger.setLevel(getattr(logging, args.log_level.upper()))
    handler = logging.StreamHandler()
    logger.addHandler(handler)
    match args.command:
        case 'check-credentials':
            try:
                check_credentials(args.username, hashlib.md5(args.password.encode()).hexdigest())
                parser.exit(message='Credentials are valid.')
            except HTTPError as e:
                if e.response.status_code == 401:
                    parser.exit(1, message='Invalid credentials.')
                else:
                    parser.exit(1, message=f'Error checking credentials: {e}')
        case 'sync':
            try:
                sync()
            except HTTPError as e:
                logging.error(f'Error syncing profiles: {e}')
                parser.exit(1)
            except Exception as e:
                logging.error(f'Unexpected error: {e}')
                parser.exit(1)
        case 'disable':
            # remove config
            disable()
        case 'daemon':
            # launch a process with the sync command once every 30 seconds, with a random delay of max 5 seconds
            # shutoff event handling in case of SIGTERM, SIGINT or SIGHUP
            def shutoff(_, _frame):
                logging.info('Shutting down flashstart sync daemon')
                event_handler.set()
            event_handler = Event()
            signal.signal(signal.SIGTERM, shutoff)

            # run the sync command every 30 +/- random(0, 5) seconds
            wait_time = random.randint(25, 35)
            logging.info('Starting flashstart sync daemon')
            __update_version()
            # this is 9 to trigger the first ping immediately
            ping_timer = 9
            while not event_handler.is_set():
                try:
                    if ping_timer >= 9:
                        ping()
                        ping_timer = 0
                    else:
                        ping_timer += 1
                    sync()
                except HTTPError as e:
                    logging.error(f'Error syncing profiles: {e}')
                except Exception as e:
                    logging.error(f'Unexpected error: {e}')
                event_handler.wait(wait_time)
                # If the application exceeds the max_memory value in size, break the loop
                max_memory = 500000
                if resource.getrusage(resource.RUSAGE_SELF).ru_maxrss > max_memory:
                    logger.warning('Memory usage exceeded 500MB, stopping the daemon.')
                    break
                # re-init the UCI object to avoid issues with the config being changed
                __refresh_uci()
            # remove the config
            disable()
        case 'ping':
            ping()
        case _:
            parser.print_help()

if __name__ == '__main__':
    main()
