#!/usr/bin/python3

#
# Copyright (C) 2022 Nethesis S.r.l.
# SPDX-License-Identifier: GPL-2.0-only
#

import os
import pwd
import grp
import sys
import os.path
import shutil
import nsmigration
import subprocess
import ipaddress
import pyotp
from nethsec import firewall, utils, ovpn

def save_cert(path, data):
    """Write ``data`` to ``path`` as a private file owned by nobody:nogroup.

    The file is written while the process umask is 0o077, so a newly created
    file gets no group/other permission bits. The previous umask is restored
    right after the write, so files created later by this process are not
    silently affected by this helper.

    Args:
        path: destination file path.
        data: text content to write.

    Raises:
        OSError: if the file cannot be written or its ownership changed.
        KeyError: if the ``nobody`` user or the ``nogroup`` group is unknown.
    """
    # setup rw permissions only for nobody user
    previous_umask = os.umask(0o077)
    try:
        with open(path, 'w') as f:
            f.write(data)
    finally:
        # os.umask() is process-wide: always put the caller's mask back
        os.umask(previous_umask)

    uid = pwd.getpwnam("nobody").pw_uid
    gid = grp.getgrnam("nogroup").gr_gid
    os.chown(path, uid, gid)

def is_bridge(u, name):
    """Tell whether the network device called ``name`` is a bridge.

    Looks at every ``device`` section of the ``network`` configuration and
    returns True when at least one of them has a matching ``name`` and a
    ``type`` equal to ``bridge``; returns False otherwise.
    """
    all_devices = utils.get_all_by_type(u, "network", "device")
    return any(
        all_devices[key].get("name") == name
        and all_devices[key].get("type") == "bridge"
        for key in all_devices
    )

def get_bridge_by_ip(u, ip):
    """Return the static interface that sits on a bridge and owns ``ip``.

    Walks the ``interface`` sections of the ``network`` configuration and
    returns the first section whose ``proto`` is ``static``, whose ``device``
    is a bridge and whose ``ipaddr`` equals ``ip``. Returns None when no
    section matches.
    """
    for iface in utils.get_all_by_type(u, "network", "interface"):
        # Only statically configured interfaces are considered
        if u.get("network", iface, "proto", default="") != "static":
            continue
        dev_name = u.get("network", iface, "device", default="")
        if not is_bridge(u, dev_name):
            continue
        if u.get("network", iface, "ipaddr", default="") == ip:
            return iface
    return None

def add_tap_to_bridge(u, bridge, interface):
    """Append ``interface`` to the ports of the bridge device behind ``bridge``.

    ``bridge`` is the name of a ``network`` interface section; its ``device``
    option is resolved to the matching ``device`` section of type ``bridge``.
    When ``interface`` is not yet listed in the ``ports`` of that section it
    is appended and the ``network`` configuration is committed. Only the
    first matching device section is processed.

    The update is best effort: a failure while reading or writing the ports
    is reported through ``nsmigration.vprint`` and does not abort the caller.
    Nothing is done when ``bridge`` or ``interface`` is None, or when the
    interface section has no ``device``.
    """
    if bridge is None or interface is None:
        return
    bridge_device = u.get("network", bridge, "device", default=None)
    if not bridge_device:
        return
    devices = utils.get_all_by_type(u, "network", "device")
    for d in devices:
        device = devices[d]
        if device.get("name") != bridge_device or device.get("type") != "bridge":
            continue
        try:
            ports = u.get_all("network", d, "ports")
            if interface not in ports:
                ports = list(ports)
                ports.append(interface)
                u.set("network", d, "ports", ports)
                u.commit("network")
        except Exception as exc:
            # Keep the migration going, but leave a trace of what went wrong
            # instead of hiding it; KeyboardInterrupt and SystemExit are no
            # longer swallowed here.
            nsmigration.vprint(f"Unable to add {interface} to bridge {bridge_device}: {exc}")
        return


# Name of the OpenVPN instance created by this migration, its PKI directory
# and the default authentication script (replaced later when LDAP is in use)
iname="ns_roadwarrior1"
cert_dir=f"/etc/openvpn/{iname}/pki/"
auth_script = '/usr/libexec/ns-openvpn/openvpn-local-auth'

(u, data, nmap) = nsmigration.init("openvpn.json")

# Nothing to migrate
if not data:
    sys.exit(0)

# Clean existing config
try:
    u.delete('openvpn', iname)
except Exception:
    # Best effort: the deletion may fail, e.g. if the section is not there.
    # Only ordinary errors are ignored, so an interrupt or exit request
    # is no longer swallowed.
    pass
if os.path.isdir(f'/etc/openvpn/{iname}'):
    shutil.rmtree(f'/etc/openvpn/{iname}')
# Recreate the PKI for the instance; a failure here aborts the migration
subprocess.run(["/usr/sbin/ns-openvpnrw-init-pki", iname], check=True)

# Setup global options
u.set("openvpn", iname, "openvpn")
rw_options = data['rw']['options']
# Migrate from the obsolete net30 topology to the subnet topology.
# This is required to allow queries to the local dnsmasq from the VPN if
# the local-service option is enabled
if rw_options.get('topology') == 'net30':
    rw_options['topology'] = 'subnet'
for option, value in rw_options.items():
    nsmigration.vprint(f"Setting OpenVPN Road Warrior option {option}")
    u.set("openvpn", iname, option, value)

# Force device type and name and server pool
if rw_options.get('dev', '').startswith("tun"):
    dev_type = 'tun'
    server = rw_options.get('server')
    if 'ifconfig_pool' not in rw_options and server:
        # set default ifconfig_pool
        # split() tolerates repeated whitespace between network and netmask
        server_parts = server.split()
        if len(server_parts) == 2:
            server_net, server_mask = server_parts
            base_ip = ipaddress.ip_address(server_net)
            start_ip = base_ip + 2
            end_ip = base_ip + 254
            try:
                network = ipaddress.ip_network(f"{server_net}/{server_mask}", strict=False)
            except ValueError:
                # Netmask not understood: keep the fixed 253-address pool
                pass
            else:
                # base+254 only fits a /24 or wider network: never let the
                # pool run past the last host address of a smaller one
                end_ip = min(end_ip, network.broadcast_address - 1)
            u.set("openvpn", iname, "ifconfig_pool", f"{start_ip} {end_ip} {server_mask}")
            u.set("openvpn", iname, "server", f"{server} nopool")
        else:
            nsmigration.vprint(f"Unexpected OpenVPN Road Warrior server value '{server}', default pool not set")
else:
    dev_type = 'tap'

u.set("openvpn", iname, 'dev_type', dev_type)
if dev_type == 'tap':
    u.set("openvpn", iname, 'dev', 'taprw1')
    # server_bridge is removed from the migrated options; its first field is
    # the address used to find the local bridge. A missing or empty value
    # must not abort the whole migration.
    server_bridge = rw_options.pop("server_bridge", None)
    bridge_fields = server_bridge.split() if server_bridge else []
    if bridge_fields:
        ns_bridge = get_bridge_by_ip(u, bridge_fields[0])
        if ns_bridge:
            u.set("openvpn", iname, 'ns_bridge', ns_bridge)
            add_tap_to_bridge(u, ns_bridge, 'taprw1')
    else:
        nsmigration.vprint("OpenVPN Road Warrior server_bridge not set, bridge not configured")
else:
    u.set("openvpn", iname, 'dev', 'tunrw1')

# Setup server certificates
# Each entry is (file below cert_dir, key of data['rw']['ca'] holding its content)
ca_files = (
    ('ca.crt', 'ca.crt'),
    ('index', 'certindex'),
    ('index.attr', 'certindex.attr'),
    ('crl.pem', 'crl.pem'),
    ('serial', 'serial'),
    ('issued/server.crt', 'srv.crt'),
    ('private/server.key', 'srv.key'),
    # NOTE(review): ca.key receives the same content as server.key (both are
    # written from 'srv.key') - confirm this is intended
    ('private/ca.key', 'srv.key'),
)
for pki_file, ca_field in ca_files:
    save_cert(os.path.join(cert_dir, pki_file), data['rw']['ca'][ca_field])

# Point the instance to the PKI files, as (option, file below cert_dir).
# NOTE(review): 'key' is assigned twice, so the second value (private/ca.key)
# is the last one written - confirm which one is wanted
pki_options = (
    ('dh', 'dh.pem'),
    ('ca', 'ca.crt'),
    ('cert', 'issued/server.crt'),
    ('key', 'private/server.key'),
    ('key', 'private/ca.key'),
    ('crl_verify', 'crl.pem'),
)
for pki_option, pki_file in pki_options:
    u.set("openvpn", iname, pki_option, os.path.join(cert_dir, pki_file))

# Connection hooks, management socket and instance metadata.
# ns_auth_mode starts as 'certificate' and is overwritten further down when
# the migrated authentication mode involves a password or an OTP
instance_settings = (
    ('client_connect', f'"/usr/libexec/ns-openvpn/openvpn-connect {iname}"'),
    ('client_disconnect', f'"/usr/libexec/ns-openvpn/openvpn-disconnect {iname}"'),
    ('management', f'/var/run/openvpn_{iname}.socket unix'),
    ('ns_description', "Migrated"),
    ('ns_auth_mode', 'certificate'),
    ('ns_tag', ["automated", "migrated"]),
)
for setting_name, setting_value in instance_settings:
    u.set("openvpn", iname, setting_name, setting_value)

# cleanup user config, keep only main db
for users_section in u.get_all('users'):
    section_type = u.get('users', users_section)
    if section_type == 'ldap' or section_type == 'user':
        u.delete('users', users_section)

# Pick the user database: a remote LDAP one when LDAP settings were
# exported (with the remote auth script), the local "main" one otherwise
if data['ldap']:
    user_db = "ns_ldap1"
    auth_script = '/usr/libexec/ns-openvpn/openvpn-remote-auth'
    u.set("users", "ns_ldap1", "ldap")
    for ldap_option, ldap_value in data['ldap'].items():
        u.set("users", "ns_ldap1", ldap_option,  ldap_value)
else:
    user_db = "main"

u.set('openvpn', iname, 'ns_user_db', user_db)

# Options to apply for each migrated authentication mode, in the order they
# are written; a mode not listed here leaves the instance untouched
auth_mode = data['rw']['auth']
auth_profiles = (
    ('password', (
        ('auth_user_pass_verify', f'{auth_script} via-env'),
        ('verify_client_cert', 'none'),
        ('username_as_common_name', '1'),
        ('script_security', '3'),
        ('ns_auth_mode', 'username_password'),
    )),
    ('password-certificate', (
        ('auth_user_pass_verify', f'{auth_script} via-env'),
        ('script_security', '3'),
        ('ns_auth_mode', 'username_password_certificate'),
    )),
    ('certificate-otp', (
        ('auth_user_pass_verify', '/usr/libexec/ns-openvpn/openvpn-otp-auth via-env'),
        ('script_security', '3'),
        ('ns_auth_mode', 'username_otp_certificate'),
    )),
)
for profile_name, profile_settings in auth_profiles:
    if auth_mode == profile_name:
        for auth_option, auth_value in profile_settings:
            u.set("openvpn", iname, auth_option, auth_value)
        break


# Create user entries
for user in data["users"]:
    # A user without both private key and certificate cannot be migrated
    has_credentials = user['key'] is not None and user['crt'] is not None
    if not has_credentials:
        nsmigration.vprint(f"Skipping OpenVPN Road Warrior user {user['name']}, missing key or certificate")
        continue

    user_section = utils.get_random_id()
    login = user["name"]
    nsmigration.vprint(f"Creating OpenVPN Road Warrior user {login}")
    u.set("users", user_section, "user")
    u.set("users", user_section, "name", login)
    u.set("users", user_section, "database", user_db)
    u.set("users", user_section, "openvpn_ipaddr", user["ipaddr"])
    u.set("users", user_section, "openvpn_enabled", user["enabled"])

    # The password is optional and stored only when not empty
    password = user.get('password')
    if password:
        u.set("users", user_section, "password", password)

    # Reuse the exported 2FA secret, or generate a fresh one when missing
    otp_secret = user.get('2fa') or pyotp.random_base32()
    u.set("users", user_section, "openvpn_2fa", otp_secret)

    # Key and certificate are stored in the instance PKI, named after the user
    save_cert(os.path.join(cert_dir, f'private/{login}.key'), user['key'])
    save_cert(os.path.join(cert_dir, f'issued/{login}.crt'), user['crt'])

# Setup firewall
# Open the migrated OpenVPN port for both TCP and UDP, linked to the instance
rw_port = data['rw']['options']['port']
firewall.add_service(u, 'OpenVPNRW1', rw_port, ['tcp', 'udp'], link=f"openvpn/{iname}")
# NOTE(review): the zone always receives 'tunrw1', even when the instance
# was configured above with the 'taprw1' device - confirm this is intended
rw_zone = "rwopenvpn"
firewall.add_trusted_zone(u, rw_zone, link="openvpn/ns_roadwarrior_noremove")
firewall.add_device_to_zone(u, 'tunrw1', rw_zone)

# Save configuration
for changed_config in ("openvpn", "users"):
    u.commit(changed_config)
firewall.apply(u)
