#!/usr/bin/env python3

# pylint: disable=missing-function-docstring
# pylint: disable=missing-module-docstring

import subprocess
import argparse
import logging
import crypt
import os
import sys
import toml

__version__ = "1.0"

imager_custom_path = os.path.join('/', 'usr', 'lib', 'raspberrypi-sys-mods', 'imager_custom')
userconf_path = os.path.join('/', 'usr', 'lib', 'userconf-pi', 'userconf')

def set_hostname (hostname):
    """Run ``imager_custom set_hostname <hostname>``; do nothing for a falsy value.

    A non-zero exit status from the helper is logged as an error rather than
    raised.
    """
    if hostname:
        argv = (imager_custom_path, 'set_hostname', hostname)
        try:
            subprocess.run(argv, check=True, encoding='UTF-8')
        except subprocess.CalledProcessError:
            logging.error("Failed to set hostname")
        else:
            logging.info("Hostname set to %s", hostname)

def config_system(system_config):
    """Apply the [system] section; the ``hostname`` key is consumed from the dict."""
    hostname = system_config.pop("hostname", None)
    set_hostname(hostname)

def rename_user(name, password, is_encrypted=True):
    """Rename the user and set its password through the userconf helper.

    Nothing happens when neither ``name`` nor ``password`` is given. The
    request is skipped (with an error logged) when a ``firstrun.sh`` script
    already references userconf, or when a password is given without a name.

    Args:
        name: New username.
        password: Password for the user. When empty, "!" is passed instead
            and a warning notes that the account will be locked.
        is_encrypted: When true, ``password`` is passed to userconf
            unchanged; when false it is hashed with SHA-512 crypt first.
    """
    if not name and not password:
        return

    # Inspect each candidate location on its own: a missing /boot/firstrun.sh
    # must not prevent /boot/firmware/firstrun.sh from being checked.
    for frloc in ('/boot/firstrun.sh', '/boot/firmware/firstrun.sh'):
        try:
            # errors='replace' keeps stray non-UTF-8 bytes from aborting the
            # scan; the marker being searched for is plain ASCII.
            with open(frloc, 'r', encoding='utf-8', errors='replace') as f:
                conflicts = '/usr/lib/userconf-pi/userconf' in f.read()
        except FileNotFoundError:
            continue
        if conflicts:
            logging.error ("User configuration skipped: conflicts with firstrun.sh")
            return

    if not name:
        logging.error ("User configuration skipped: specified a password without a username")
        return
    if not password:
        logging.warning("%s will be locked. No password specified", name)
        password = "!"
    elif not is_encrypted:
        logging.warning("Plain text user password provided. Encrypting...")
        password = crypt.crypt(password, crypt.mksalt(crypt.METHOD_SHA512))

    cmd = (userconf_path, name, password)
    try:
        subprocess.run(cmd, encoding='UTF-8', check=True)
    except subprocess.CalledProcessError:
        logging.error("Failed to rename user %s", name)
    else:
        logging.info ("User renamed to %s", name)

def config_user(user_config):
    """Apply the [user] section by consuming its keys and calling rename_user().

    ``name`` and ``password`` default to None; ``password_encrypted`` defaults
    to True.
    """
    rename_user(
        user_config.pop("name", None),
        user_config.pop("password", None),
        user_config.pop("password_encrypted", True),
    )

def import_ssh_id(ids):
    """Run ``imager_custom import_ssh_id`` for one ID or a collection of IDs.

    A bare string is treated as a single ID. Falsy input is ignored, and a
    failing helper is logged as an error rather than raised.
    """
    if not ids:
        return
    if isinstance(ids, str):
        ids = (ids,)
    argv = (imager_custom_path, "import_ssh_id") + tuple(ids)
    try:
        subprocess.run(argv, check=True, encoding='UTF-8')
    except subprocess.CalledProcessError:
        logging.error("Failed to import pubkeys for %s", ", ".join(ids))
    else:
        logging.info ("Importing pubkeys for: %s", ", ".join(ids))

def config_ssh (ssh_config):
    """Apply the [ssh] section through ``imager_custom enable_ssh``.

    Does nothing for an empty or missing section. Keys consumed from
    ``ssh_config``:

        ssh_import_id: ID or list of IDs forwarded to import_ssh_id().
        enabled: when false (the default) ``-d`` is added to the command.
        password_authentication: True adds ``-p``, False adds ``-k``; when
            omitted neither flag is added.
        authorized_keys: a single key string or a list of key strings,
            appended to the command as trailing arguments.
    """
    if not ssh_config:
        return
    ssh_import_id = ssh_config.pop("ssh_import_id", None)
    import_ssh_id(ssh_import_id)

    ssh_enabled = ssh_config.pop("enabled", False)
    ssh_password_authentication = ssh_config.pop("password_authentication", None)
    ssh_authorized_keys = ssh_config.pop("authorized_keys", [])
    # A lone string must stay one argument; tuple("ssh-ed25519 ...") would
    # otherwise split it into individual characters. This mirrors how
    # import_ssh_id() accepts either a string or a list.
    if isinstance(ssh_authorized_keys, str):
        ssh_authorized_keys = (ssh_authorized_keys,)

    cmd = (imager_custom_path, "enable_ssh")
    if not ssh_enabled:
        cmd = cmd + ('-d',)
    if ssh_password_authentication is not None:
        cmd = cmd + ('-p' if ssh_password_authentication else '-k',)
    cmd = cmd + tuple(ssh_authorized_keys)
    try:
        subprocess.run(cmd, encoding='UTF-8', check=True)
    except subprocess.CalledProcessError:
        logging.error("Failed to configure SSH")
    else:
        logging.info ("Configured SSH")

def wpa_passphrase (ssid, passphrase):
    """Return the PSK that the ``wpa_passphrase`` tool prints for an SSID.

    The passphrase is fed to the tool on standard input. The value of the
    first output line that, once stripped, starts with ``psk=`` is returned.

    Raises:
        subprocess.CalledProcessError: the tool exited with a non-zero status.
        StopIteration: the output contained no ``psk=`` line.
    """
    result = subprocess.run(('wpa_passphrase', ssid), input=passphrase,
                            stdout=subprocess.PIPE, check=True, encoding='UTF-8')
    stripped = map(str.strip, result.stdout.splitlines())
    return next(entry.split('=', maxsplit=1)[1]
                for entry in stripped if entry.startswith("psk="))

def config_wlan (wlan_config):
    """Apply the [wlan] section through the imager_custom helper.

    Does nothing for an empty or missing section. Keys consumed from
    ``wlan_config``:

        ssid: network name; without it no connection is configured.
        password: passphrase or pre-computed PSK for the network.
        password_encrypted: when false, the password is converted to a PSK
            with wpa_passphrase() before use (defaults to True).
        hidden: when true ``-h`` is added to the set_wlan command.
        country: when set, applied with ``imager_custom set_wlan_country``.

    The country is applied independently of the connection, so a section
    containing only ``country`` is valid and does not invoke ``set_wlan``.
    """
    if not wlan_config:
        return
    wlan_ssid = wlan_config.pop("ssid", None)
    wlan_password = wlan_config.pop("password", None)
    wlan_password_is_encrypted = wlan_config.pop("password_encrypted", True)
    wlan_is_hidden = wlan_config.pop("hidden", False)
    wlan_country = wlan_config.pop("country", "")

    if wlan_country:
        try:
            subprocess.run(
                (imager_custom_path, "set_wlan_country", wlan_country),
                encoding='UTF-8', check=True)
            logging.info ("WLAN country set to %s", wlan_country)
        except subprocess.CalledProcessError:
            logging.error ("Failed to set WLAN country")

    # Without an SSID there is no connection to configure. Running set_wlan
    # with no SSID argument would only produce a spurious failure (or a
    # connection with an empty name), so stop here.
    if not wlan_ssid:
        if wlan_password:
            logging.error ("WLAN connection skipped: specified WLAN password without an SSID")
        return

    cmd = (imager_custom_path, "set_wlan")
    if wlan_is_hidden:
        cmd = cmd + ('-h',)

    if wlan_password and not wlan_password_is_encrypted:
        try:
            wlan_password = wpa_passphrase (wlan_ssid, wlan_password)
        except (subprocess.CalledProcessError, StopIteration, FileNotFoundError):
            # PSK generation failed (tool error, unexpected output, or the
            # wpa_passphrase binary is missing): hand over the password as
            # given, adding -p -- presumably set_wlan's "plain text password"
            # flag; confirm against imager_custom.
            cmd = cmd + ('-p',)
            logging.warning ("Could not generate PSK for %s", wlan_ssid)

    cmd = cmd + (wlan_ssid,)
    if wlan_password:
        cmd = cmd + (wlan_password,)

    try:
        subprocess.run(cmd, encoding='UTF-8', check=True)
        logging.info ("Configuring WLAN: %s", wlan_ssid)
    except subprocess.CalledProcessError:
        logging.error ("Failed to configure WLAN")


def set_keymap (keymap):
    """Run ``imager_custom set_keymap <keymap>``; falsy values are ignored.

    A non-zero exit status from the helper is logged as an error rather than
    raised.
    """
    if not keymap:
        return
    argv = (imager_custom_path, 'set_keymap', keymap)
    try:
        subprocess.run(argv, check=True, encoding='UTF-8')
    except subprocess.CalledProcessError:
        logging.error("Failed to set keymap")
        return
    logging.info("Keymap set to %s", keymap)

def set_timezone (timezone):
    """Run ``imager_custom set_timezone <timezone>``; falsy values are ignored.

    A non-zero exit status from the helper is logged as an error rather than
    raised.
    """
    if timezone:
        argv = (imager_custom_path, 'set_timezone', timezone)
        try:
            subprocess.run(argv, check=True, encoding='UTF-8')
        except subprocess.CalledProcessError:
            logging.error("Failed to set timezone")
        else:
            logging.info("Timezone set to %s", timezone)

def config_locale (locale_config):
    """Apply the [locale] section: ``keymap`` first, then ``timezone``.

    Each key is consumed from the dict and defaults to None when absent.
    """
    for key, apply_setting in (("keymap", set_keymap), ("timezone", set_timezone)):
        apply_setting(locale_config.pop(key, None))


# Command line: an optional TOML file path (standard input when omitted).
parser = argparse.ArgumentParser(
    description="Read and apply a custom configuration to Raspberry Pi OS from a TOML file"
)
parser.add_argument(
    "toml_file", nargs="?", default="/dev/stdin",
    help="Path to a TOML file containing initial configuration (defaults to stdin)",
)
parser.add_argument('--version', action='version', version=f"%(prog)s {__version__}")
args = parser.parse_args()

logging.basicConfig(format="%(message)s", level=logging.INFO)

try:
    config = toml.load(args.toml_file)
except toml.decoder.TomlDecodeError as err:
    logging.error ("Error parsing %s: %s", args.toml_file, err)
    sys.exit(1)
except OSError as err:
    # Missing or unreadable input: report it cleanly instead of a traceback.
    logging.error ("Error reading %s: %s", args.toml_file, err)
    sys.exit(1)

config_version = config.pop("config_version", None)
if config_version != 1:
    # NOTE(review): an unsupported version is only reported and processing
    # continues; confirm whether this should abort instead.
    logging.error ("Unsupported config version")

# Explicit section -> handler mapping, processed in this order. Each handler
# pops the keys it understands, so whatever remains afterwards is unknown.
section_handlers = {
    "system": config_system,
    "user": config_user,
    "ssh": config_ssh,
    "wlan": config_wlan,
    "locale": config_locale,
}
supported_sections = tuple(section_handlers)
for s in supported_sections:
    section_config = config.pop(s, {})
    if not isinstance(section_config, dict):
        # e.g. `ssh = true` at top level instead of an [ssh] table.
        logging.error("Section [%s] skipped: expected a table", s)
        continue
    section_handlers[s](section_config)
    for key in section_config:
        logging.warning("Unknown key in [%s]: %s", s, key)

for key in config:
    logging.warning("Unknown section or key: %s", key)
