#!/usr/bin/python

#
# Copyright (C) 2024 Nethesis S.r.l.
# SPDX-License-Identifier: GPL-2.0-only
#

#
# This script generates nftables rules for netmap
# It reads the netmap configuration from UCI /etc/config/netmap and generates the nftables rules
# for source and destination NAT
#

import os
import subprocess
import sys
import ipaddress
from euci import EUci
from nethsec import utils

# Snippet file that receives the generated snat rules (see generate_rules)
SRC_FILE = "/usr/share/nftables.d/chain-pre/srcnat/20netmap.nft"
# Snippet file that receives the generated dnat rules (see generate_rules)
DST_FILE = "/usr/share/nftables.d/chain-post/dstnat/20netmap.nft"

def setup():
    """Create the parent directories of SRC_FILE and DST_FILE.

    ``exist_ok=True`` makes the call a no-op when the directory is already
    there and removes the window between checking for the directory and
    creating it.

    Raises:
        OSError: if a directory cannot be created (for example when the path
            exists but is not a directory).
    """
    for target in (SRC_FILE, DST_FILE):
        os.makedirs(os.path.dirname(target), exist_ok=True)

def cleanup():
    """Remove SRC_FILE and DST_FILE; a file that is already absent is ignored.

    Raises:
        OSError: for any removal failure other than the file not existing.
    """
    for target in (SRC_FILE, DST_FILE):
        try:
            os.remove(target)
        except FileNotFoundError:
            # Already gone: this is the state cleanup is meant to reach
            pass

def _device_list(value):
    """Return *value* as a list of device names, or None when no device is set.

    ``None`` and the ``('',)`` tuple both mean "no device filter".
    """
    if value is None or value == ('',):
        return None
    if isinstance(value, str):
        # list() on a plain string would split it into single characters
        return [value] if value else None
    return list(value)


def generate_rules():
    """Write the netmap NAT rules read from UCI into the nftables snippet files.

    Every ``rule`` section of the ``netmap`` UCI configuration yields at most
    one nftables line:

    * a rule with ``dest`` set produces an ``snat`` prefix map in SRC_FILE;
    * otherwise, a rule with ``src`` set produces a ``dnat`` prefix map in
      DST_FILE.

    Rules whose ``map_from`` or ``map_to`` is rejected by
    ``ipaddress.ip_network``, or that have neither ``dest`` nor ``src``, are
    skipped. Both files are always (re)created, even when nothing is written
    to them.
    """
    e_uci = EUci()
    # NOTE(review): assumes get_all_by_type may return a falsy value when the
    # configuration is unavailable - TODO confirm; it is treated as "no rules".
    rules = utils.get_all_by_type(e_uci, "netmap", "rule") or {}

    # The context manager closes both files even if a rule raises midway
    with open(SRC_FILE, "w") as sf, open(DST_FILE, "w") as df:
        k = 0  # number of rules written so far; used for the default name
        for section in rules:
            rule = rules[section]
            name = rule.get("name", f"netmap{k}")
            src = rule.get("src")
            dest = rule.get("dest")
            device_in = _device_list(rule.get("device_in"))
            device_out = _device_list(rule.get("device_out"))
            map_from = rule.get("map_from")
            map_to = rule.get("map_to")

            # Skip rules whose mapping is not a valid network; ip_network is
            # strict by default, so addresses with host bits set are rejected.
            # NOTE(review): ip_network also accepts IPv6 networks while the
            # rules below use the `ip` family - confirm whether IPv6 values
            # should be rejected here.
            try:
                ipaddress.ip_network(map_from)
                ipaddress.ip_network(map_to)
            except ValueError:
                continue

            # Optional interface matches are prepended to the NAT statement
            nft = ''
            if device_in:
                nft = "iifname {\"" + "\",\"".join(device_in) + "\"} "
            if device_out:
                nft = nft + "oifname {\"" + "\",\"".join(device_out) + "\"} "

            if dest:
                nft = nft + f"ip daddr {dest} snat ip prefix to ip saddr map {{ {map_from} : {map_to} }} comment \"!fw4: {name}\"\n"
                sf.write(nft)
            elif src:
                nft = nft + f"ip saddr {src} dnat ip prefix to ip daddr map {{ {map_from} : {map_to} }} comment \"!fw4: {name}\"\n"
                df.write(nft)
            else:
                # Neither dest nor src: nothing to match on, no rule written
                continue
            k += 1

def main():
    """Regenerate (or only remove) the netmap snippets, then reload the firewall.

    When the first command-line argument is ``cleanup`` only the generated
    files are removed. Otherwise the target directories are prepared, the old
    files are removed and the rules are generated again. The firewall reload
    runs in both cases.
    """
    # An empty slice compares unequal, so a missing argument is handled too
    cleanup_only = sys.argv[1:2] == ["cleanup"]

    if not cleanup_only:
        setup()
    cleanup()
    if not cleanup_only:
        generate_rules()

    reload_cmd = ['/etc/init.d/firewall', 'reload']
    subprocess.run(reload_cmd)

# Run main() only when this file is executed as a script, not when imported
if __name__ == "__main__":
    main()
