#!/usr/bin/python3

#
# Copyright (C) 2024 Nethesis S.r.l.
# SPDX-License-Identifier: GPL-2.0-only
#

import argparse
import os
import sys
import glob
import shutil
import logging
import urllib.request
import tarfile
import subprocess
from nethsec import snort
from euci import EUci

# Rule data is kept on persistent storage when /mnt/data exists (so rule
# archives survive reboots); otherwise it falls back to /var
if os.path.isdir("/mnt/data"):
    DATA_DIR = "/mnt/data/ns-snort"
else:
    DATA_DIR = "/var/ns-snort"

# Directory holding the rule files generated by this script
RULES_DIR = os.path.join(DATA_DIR, "rules")
# Copy of RULES_DIR taken before each regeneration
BACKUP_DIR = os.path.join(DATA_DIR, "old.rules")
# Optional file with the built-in test rules
TESTING_RULES_FILE = os.path.join(RULES_DIR, "testing.rules")
# Extraction target for the downloaded official rule archives
OFFICIAL_RULES_DIR = os.path.join(DATA_DIR, "snort-rules")
COMMUNITY_RULES_URL = "https://www.snort.org/downloads/community/snort3-community-rules.tar.gz"

# Log plain messages (no level/timestamp prefix) at INFO and above
logging.basicConfig(format='%(message)s', level=logging.INFO)

def log(message):
    """Emit ``message`` at INFO level through the root logger."""
    logging.log(logging.INFO, message)

def backup_rules():
    """Replace the previous backup with a copy of the current rules directory.

    Any existing BACKUP_DIR is deleted first (errors are ignored). When
    RULES_DIR does not exist, no new backup is created.
    """
    # Discard the stale backup; a missing directory is not an error
    shutil.rmtree(BACKUP_DIR, ignore_errors=True)

    if not os.path.exists(RULES_DIR):
        return

    shutil.copytree(RULES_DIR, BACKUP_DIR)
    log(f"Backup created at {BACKUP_DIR}")

def generate_testing_rules():
    """Write the built-in ICMP test rules to TESTING_RULES_FILE.

    RULES_DIR is created when missing but is never emptied: other rule files
    stored there (e.g. an already generated ``snort.rules``) must survive
    enabling the testing rules. An existing testing file is overwritten.
    """
    os.makedirs(RULES_DIR, exist_ok=True)
    testing_rules = [
        'alert icmp any any <> any any (msg:"TEST ALERT ICMP v4"; icode:0; itype:8; sid:99010;)',
        'alert icmp any any <> any any (msg:"TEST ALERT ICMP v6"; icode:0; itype:33; sid:99011;)',
        'alert icmp any any <> any any (msg:"TEST ALERT ICMP v6"; icode:0; itype:34; sid:99012;)'
    ]
    with open(TESTING_RULES_FILE, 'w') as f:
        f.writelines(rule + "\n" for rule in testing_rules)

def download_official_rules(oinkcode):
    """Download and unpack the official Snort rules.

    With a truthy ``oinkcode`` the subscriber snapshot is fetched, otherwise
    the community rules. The archive is saved inside DATA_DIR and extracted
    into OFFICIAL_RULES_DIR. If the download or the extraction fails, the
    error is logged and the process exits with status 1.

    The oinkcode is a credential, so it is masked in every log message.

    Args:
        oinkcode: Snort subscriber code, or a falsy value to download the
            community rules.
    """
    if oinkcode:
        log("Downloading subscriber rules...")
        # remember to change the URL even in the `ns.snort` script
        url = f"https://www.snort.org/rules/snortrules-snapshot-31470.tar.gz?oinkcode={oinkcode}"
        archive_loc = f"snortrules-snapshot-{oinkcode}"
    else:
        log("Downloading community rules...")
        url = COMMUNITY_RULES_URL
        archive_loc = "snort3-community-rules"

    def redact(text):
        # Both the URL and the archive file name embed the oinkcode
        text = str(text)
        return text.replace(str(oinkcode), "<oinkcode>") if oinkcode else text

    os.makedirs(OFFICIAL_RULES_DIR, exist_ok=True)
    archive_path = os.path.join(DATA_DIR, f"{archive_loc}.tar.gz")

    try:
        with urllib.request.urlopen(url, timeout=10) as response:
            if response.status != 200:
                log(f"Failed to download rules from {redact(url)}")
                # SystemExit is not an Exception, so it is not caught below
                sys.exit(1)
            with open(archive_path, 'wb') as f:
                shutil.copyfileobj(response, f)
    except Exception as e:
        log(f"Failed to download rules from {redact(url)}: {redact(e)}")
        sys.exit(1)
    log(f"Downloaded rules to {redact(archive_path)}")

    try:
        with tarfile.open(archive_path, 'r:gz') as tar:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: refuse members that
                # would be written outside OFFICIAL_RULES_DIR, device nodes, etc.
                tar.extractall(path=OFFICIAL_RULES_DIR, filter="data")
            else:
                # Interpreter without extraction filters: legacy behaviour
                tar.extractall(path=OFFICIAL_RULES_DIR)
    except Exception as e:
        log(f"Failed to extract rules from {redact(archive_path)}: {redact(e)}")
        sys.exit(1)
    log(f"Extracted rules to {OFFICIAL_RULES_DIR}")

def filter_official_rules(policy, alert_excluded, disabled_rules, oinkcode=None):
    """Select the official rules to install for the given policy.

    Rules are read from the subscriber rule files when ``oinkcode`` is truthy,
    otherwise from the community rule file; missing files are skipped.
    Rules without metadata, rules that are not enabled and rules listed in
    ``disabled_rules`` are dropped. A rule whose metadata carries the
    ``policy <policy>-ips drop`` marker is turned into a ``block`` rule; every
    other rule is kept as-is only when ``alert_excluded`` is truthy and its
    message does not contain ``PROTOCOL-ICMP``.

    Args:
        policy: one of ``connectivity``, ``balanced``, ``security`` or
            ``max-detect``; any other value selects no blocking rule.
        alert_excluded: when truthy, keep the rules excluded by the policy.
        disabled_rules: iterable of ``<gid>,<sid>[,<description>]`` strings.
            Entries without a comma are logged and ignored.
        oinkcode: subscriber code; only its truthiness is used here.

    Returns:
        An unordered list holding the raw text of the blocking rules followed
        by the parsed rule objects of the excluded rules (set semantics, so
        duplicates are collapsed).
    """
    policy_markers = {
        "connectivity": 'policy connectivity-ips drop',
        "balanced": 'policy balanced-ips drop',
        "security": 'policy security-ips drop',
        "max-detect": 'policy max-detect-ips drop',
    }
    drop_marker = policy_markers.get(policy)

    drop_rules = set()
    alert_rules = set()
    ret = set()
    disabled_sids = set()

    # A disabled rule is a list of 3 elements separated by commas:
    # <gid>,<sid>,<description>
    # the description is optional, the following chars are not allowed: , \n and space
    for d in disabled_rules:
        parts = d.split(",")
        if len(parts) < 2:
            log(f"Ignoring malformed disabled rule entry: {d}")
            continue
        disabled_sids.add(f'{parts[0]}:{parts[1]}')

    if oinkcode:
        rule_files = glob.glob(os.path.join(OFFICIAL_RULES_DIR, "rules/*.rules"))
    else:
        # community rules are extracted to
        # <OFFICIAL_RULES_DIR>/snort3-community-rules/snort3-community.rules
        rule_files = [os.path.join(OFFICIAL_RULES_DIR, "snort3-community-rules/snort3-community.rules")]

    for rule_path in rule_files:
        if not os.path.exists(rule_path):
            continue
        log(f"Processing rule file {rule_path}")
        for rule in snort.parse_file(rule_path):
            # rule.metadata can contain:
            # - policy balanced-ips drop
            # - policy connectivity-ips drop
            # - policy security-ips drop
            # - policy max-detect-ips drop
            if rule.metadata is None or not rule.enabled or f'{rule.gid}:{rule.sid}' in disabled_sids:
                log(f"Skipping disabled rule {rule.gid}:{rule.sid}")
                continue
            if drop_marker is not None and drop_marker in rule.metadata:
                drop_rules.add(rule)
            elif alert_excluded and 'PROTOCOL-ICMP' not in (rule.msg or ""):
                # Add excluded rules as alerts but not if they are ICMP (too noisy)
                alert_rules.add(rule)

    # Rewrite the action exactly once per rule, after every file has been
    # parsed: doing it per file would re-process rules collected from earlier
    # files and could replace a later occurrence of the action word.
    for rule in drop_rules:
        if rule.action != 'block':
            rule.raw = rule.raw.replace(rule.action, "block", 1)
        ret.add(rule.raw)
    # append alert rules (NOTE(review): these are rule objects, not raw
    # strings; callers are expected to render them with str())
    ret.update(alert_rules)
    return list(ret)

def prepare_rule_file():
    """Delete any existing ``snort.rules`` in RULES_DIR and return its path."""
    target = os.path.join(RULES_DIR, "snort.rules")
    already_there = os.path.exists(target)
    if already_there:
        os.remove(target)
    return target

def append_rules_to_file(rules):
    """Append ``rules`` to ``snort.rules`` inside RULES_DIR, one per line.

    RULES_DIR is created when missing. When the file does not exist yet it is
    started with a header comment. Each rule is rendered with ``str()``.

    Args:
        rules: sized collection of rules (strings or objects with a
            meaningful ``str()``).
    """
    # create dir if not exists
    os.makedirs(RULES_DIR, exist_ok=True)
    rule_file = os.path.join(RULES_DIR, "snort.rules")
    is_new_file = not os.path.exists(rule_file)
    # Open the file once instead of once per rule; append mode creates it
    with open(rule_file, 'a') as f:
        if is_new_file:
            f.write("# This file is automatically generated by ns-snort-rules\n")
        f.writelines(str(rule) + "\n" for rule in rules)
    log(f"Appended {len(rules)} rules to {rule_file}")

def main():
    """Regenerate the Snort rule files from the ``snort`` UCI configuration.

    Steps: back up the current rules, optionally download the official rules
    (``--download``), filter them according to ``ns_policy``,
    ``ns_alert_excluded`` and ``ns_disabled_rules``, add or remove the testing
    rules (``ns_testing``), rewrite ``snort.rules`` and optionally restart
    Snort (``--restart``).
    """
    uci = EUci()
    parser = argparse.ArgumentParser(description='Process snort rules.')

    # General options
    parser.add_argument('--download', action='store_true', help="Download Snort rules.")
    parser.add_argument('--restart', action='store_true', help="Force Snort to restart after updating rules.")
    args = parser.parse_args()

    backup_rules()

    try:
        disabled_rules = list(uci.get_all("snort", "snort", "ns_disabled_rules"))
    except Exception:
        # Option missing or unreadable: no rule is disabled
        disabled_rules = []

    oinkcode = uci.get("snort", "snort", "oinkcode", default=None)
    if args.download:
        download_official_rules(oinkcode)

    official_policy = uci.get("snort", "snort", "ns_policy", default="security")
    # Read as a boolean: as a raw string, a value of "0" would be truthy
    alert_excluded = uci.get("snort", "snort", "ns_alert_excluded", default=False, dtype=bool)
    rules = filter_official_rules(official_policy, alert_excluded, disabled_rules, oinkcode)

    # Testing rules are handled before snort.rules is written, so nothing done
    # to RULES_DIR at this step can affect the freshly generated file.
    if uci.get("snort", "snort", "ns_testing", default=False, dtype=bool):
        log("Adding testing rules...")
        generate_testing_rules()
    else:
        log("Removing testing rules...")
        if os.path.exists(TESTING_RULES_FILE):
            os.remove(TESTING_RULES_FILE)

    # Drop the old snort.rules only once the new rule set is ready: a failed
    # download or parse above leaves the previous file in place.
    prepare_rule_file()
    append_rules_to_file(rules)

    if args.restart:
        log("Restarting snort...")
        subprocess.run(["/etc/init.d/snort", "restart"], check=True)

if __name__ == "__main__":
    main()
