#!/usr/bin/python3

#
# Copyright (C) 2026 Nethesis S.r.l.
# SPDX-License-Identifier: GPL-2.0-only
#

import os
import sys
import sqlite3
from euci import EUci
from nethsec import users

def get_db_path(u, instance):
    """Return the path of the connections database for an OpenVPN instance.

    The database is placed under the directory configured at UCI
    fstab.ns_data.target when that directory exists; otherwise it falls
    back to /var.

    :param u: UCI accessor providing get(config, section, option, default=...).
    :param instance: OpenVPN instance (UCI section) name.
    :return: "<base>/openvpn/<instance>/connections.db".
    """
    data_dir = u.get('fstab', 'ns_data', 'target', default='')
    # An unset target ('') or a missing directory both select /var.
    root = data_dir if os.path.isdir(data_dir) else '/var'
    return os.path.join(root, 'openvpn', instance, 'connections.db')

# Script entry point. sys.argv[1] is the OpenVPN instance (UCI section) name;
# the connection being closed is described by environment variables
# (common_name, time_unix, time_duration, bytes_received, bytes_sent, ...).
instance = sys.argv[1]
env = os.environ

common_name = env.get('common_name')
if not common_name:
    # Nothing to record: exit before opening (and possibly creating) the DB.
    sys.exit(0)

start_time = int(env.get('time_unix', '0'))
duration = int(env.get('time_duration', '0'))
bytes_received = int(env.get('bytes_received', '0'))
bytes_sent = int(env.get('bytes_sent', '0'))

uci = EUci()
conn = sqlite3.connect(get_db_path(uci, instance))
try:
    c = conn.cursor()

    # Update the existing connection record (presumably created at connect
    # time), matched by common name and connection start time.
    c.execute(
        "UPDATE connections SET duration=?, bytes_received=?, bytes_sent=? WHERE common_name=? and start_time=?",
        (duration, bytes_received, bytes_sent, common_name, start_time),
    )

    if c.rowcount == 0:
        # Disconnection of a client whose connection record was saved on
        # storage that is not available now: insert a new full record into
        # the currently selected DB, but only if the user is currently enabled
        # (avoid creating entries for disabled users that still generate
        # disconnect events).
        enabled = False
        try:
            db = uci.get('openvpn', instance, 'ns_user_db', default=None)
            user = users.get_user_by_name(uci, common_name, db) if db else None
            enabled = user is not None and user.get('openvpn_enabled', '0') == '1'
        except Exception as exc:
            # Best effort: on any error while checking UCI/users, assume the
            # user is disabled. Catching Exception (not a bare except) keeps
            # SystemExit/KeyboardInterrupt working; the warning makes the
            # failure visible instead of silently dropping the record.
            print(
                f"warning: cannot check OpenVPN status of user {common_name!r}: {exc}",
                file=sys.stderr,
            )

        if enabled:
            virtual_ip_addr = env.get('ifconfig_pool_remote_ip', '')
            remote_ip_addr = env.get('untrusted_ip', '')
            c.execute(
                "INSERT INTO connections (common_name, virtual_ip_addr, remote_ip_addr, start_time, duration, bytes_received, bytes_sent) "
                "VALUES (?, ?, ?, ?, ?, ?, ?)",
                (common_name, virtual_ip_addr, remote_ip_addr, start_time, duration, bytes_received, bytes_sent)
            )

    conn.commit()
finally:
    # Always release the database handle, even if a query raises.
    conn.close()

sys.exit(0)
