# Copyright 2021 Offensive Security
# SPDX-license-identifier: GPL-3.0-only

import glob
import os
import re
import shutil
import tempfile
from collections import namedtuple
from textwrap import indent
from urllib.parse import urlparse, urlunparse

from kali_tweaks.utils import (
    apt_update,
    logger,
    run_as_root,
    write_file_as_root,
)

# Source types accepted at the start of a one-line-style sources.list entry
KNOWN_TYPES = ["deb", "deb-src"]

# Mirror and protocol used when writing Kali sources and nothing else is
# known (neither requested by the user nor found in the current config)
KALI_DEFAULT_MIRROR = "http.kali.org"
KALI_DEFAULT_PROTOCOL = "http"

# Mirrors and protocols that are recognized when analyzing the current
# configuration; anything else is reported as unknown (None)
KALI_SUPPORTED_MIRRORS = [
    "http.kali.org",
    "kali.download",
]

KALI_SUPPORTED_PROTOCOLS = [
    "http",
    "https",
]

# All the known Kali suites. An APT source is considered to be a Kali
# source if its suite is one of these.
KALI_SUITES = [
    "kali-bleeding-edge",
    "kali-dev",
    "kali-experimental",
    "kali-last-snapshot",
    "kali-rolling",
]

# The suites that can be enabled/disabled in kali-tweaks
KALI_EXTRA_SUITES = [
    "kali-bleeding-edge",
    "kali-experimental",
]

# NB: keep it sorted this way. This order is used as-is for the
# components of the default Kali sources that are written out.
KALI_COMPONENTS = [
    "main",
    "contrib",
    "non-free",
    "non-free-firmware",
]

# Keyring referenced by Signed-By in deb822-style Kali stanzas (it is
# added to stanzas that have no Signed-By field when they are written out)
KALI_KEYRING = "/usr/share/keyrings/kali-archive-keyring.gpg"

# Helpers for both Deb822-Style and One-Line-Style ----------------------------


def consolidate_repo_settings(schemes, netlocs, suites):
    """
    Boil down the settings collected from several sources into one.

    Return a tuple (proto, mirror, extra_suites) where:
    * proto is the one scheme shared by all sources, or None if there
      isn't exactly one scheme or if it's not in KALI_SUPPORTED_PROTOCOLS
    * mirror is the one netloc shared by all sources, or None if there
      isn't exactly one netloc or if it's not in KALI_SUPPORTED_MIRRORS
    * extra_suites maps every suite of KALI_EXTRA_SUITES to True if it
      appears in 'suites', False otherwise

    A single empty scheme/netloc is returned unchanged.
    """

    def single_supported(values, supported):
        # The unique value if there's exactly one, unless it's a
        # non-empty value that is not supported
        distinct = set(values)
        if len(distinct) != 1:
            return None
        (value,) = distinct
        if value and value not in supported:
            return None
        return value

    proto = single_supported(schemes, KALI_SUPPORTED_PROTOCOLS)
    mirror = single_supported(netlocs, KALI_SUPPORTED_MIRRORS)

    present = set(suites)
    extra_suites = {suite: suite in present for suite in KALI_EXTRA_SUITES}

    return proto, mirror, extra_suites


# One-Line-Style Sources ------------------------------------------------------

# A source as found on a single line of a sources.list file. As built by
# mk_one_line_source(): 'options' and 'components' are lists of strings,
# 'uri' is a urllib.parse.ParseResult, 'type' and 'suite' are strings.
OneLineSource = namedtuple(
    "OneLineSource", ["type", "options", "uri", "suite", "components"]
)


def mk_one_line_source(deb_type, options, uri, suite, components):
    """
    Build a OneLineSource from the raw string fields of a sources.list line.

    'options' and 'components' are whitespace-separated strings that are
    split into lists, and 'uri' is parsed with urlparse().

    Raise ValueError if uri parsing fails.
    """

    return OneLineSource(
        deb_type,
        options.split(),
        urlparse(uri),
        suite,
        components.split(),
    )


def parse_one_line_style_line(line):
    """
    Parse one line of a sources.list file (one-line-style format), and
    return a OneLineSource named tuple.

    Return None for blank lines and comment lines. Raise ValueError if
    the line is malformed: unknown type, unterminated "[ options ]" block,
    missing fields, or unparsable URI.

    For reference regarding the format of sources.list file:
    https://manpages.debian.org/unstable/apt/sources.list.5.en.html
    """

    text = line.strip()
    if text == "" or text.startswith("#"):
        return None

    # First field: the type (deb, deb-src)
    deb_type, remainder = text.split(maxsplit=1)
    if deb_type not in KNOWN_TYPES:
        raise ValueError

    # Optional "[ ... ]" block holding the options
    if remainder[:1] == "[":
        closing = remainder.index("]")
        options = remainder[1:closing]
        remainder = remainder[closing + 1 :].lstrip()
    else:
        options = ""

    # Remaining fields: URI, suite, then all the components as one string
    uri, suite, components = remainder.split(maxsplit=2)

    return mk_one_line_source(deb_type, options, uri, suite, components)


def print_one_line_style_line(source):
    """
    Format a source as a one-line-style entry, and return the string.

    The "[ options ]" block is only emitted when the source has options.
    The line is meant to be added to a sources.list file.
    """

    if source.options:
        options_block = f"[ {' '.join(source.options)} ]"
    else:
        options_block = None

    fields = [
        source.type,
        options_block,
        urlunparse(source.uri),
        source.suite,
        " ".join(source.components),
    ]
    return " ".join(field for field in fields if field is not None)


def print_one_line_style_default_line(proto, mirror, suite):
    """
    Format the default Kali source for 'suite' in the one-line-style
    format (URI <proto>://<mirror>/kali, all of KALI_COMPONENTS), and
    return the string.

    This line is meant to be added to a sources.list file.
    """

    uri = f"{proto}://{mirror}/kali"
    return " ".join(["deb", uri, suite] + KALI_COMPONENTS)


def parse_one_line_style_sources(content):
    """
    Parse the content of a sources.list file, and return the Kali sources
    found in it as a list of OneLineSource. Only "deb" entries are kept:
    "deb-src" entries, comments, blank lines and malformed lines are
    skipped.

    What's the criterion to identify a "Kali source"? This is based on
    the suite name: if it's a known Kali suite (eg. kali-rolling,
    kali-dev, etc...) then it's a Kali source. Obviously this is not
    perfect: if ever a third-party provides packages for Kali via an
    APT repository, and uses a known Kali suite name, then it will be
    recognized as a Kali source.
    """

    kali_sources = []
    for line in content.splitlines():
        try:
            parsed = parse_one_line_style_line(line)
        except ValueError:
            parsed = None
        if parsed and parsed.type == "deb" and parsed.suite in KALI_SUITES:
            kali_sources.append(parsed)
    return kali_sources


def analyze_one_line_style_sources(sources):
    """
    Analyze a list of one-line-style sources, and return the tuple
    (protocol, mirror, extra_suites) computed by consolidate_repo_settings().
    """

    return consolidate_repo_settings(
        [src.uri.scheme for src in sources],
        [src.uri.netloc for src in sources],
        [src.suite for src in sources],
    )


def update_one_line_style_sources(
    content, protocol=None, mirror=None, remove_suites=None
):
    """
    Parse the content from a sources.list file, and modify it according to
    the arguments. Modify Kali suites only.

    Params:
    - content: content of a one-line-style sources.list file
    - protocol: new URI scheme for the Kali sources, or None to keep it
    - mirror: new mirror hostname for the Kali sources, or None to keep it;
      when the mirror changes, the URI path is reset to "/kali"
    - remove_suites: iterable of Kali suites whose lines (deb and deb-src)
      are removed; None (the default) removes nothing

    Lines that are not Kali sources (blank lines, comments, malformed or
    non-Kali entries) are kept verbatim.

    Return a tuple with a modified content, and a boolean that says whether
    the content was modified or not.
    """

    # Avoid a shared mutable default; a set makes the lookups cheap
    remove_suites = set(remove_suites) if remove_suites is not None else set()

    modified = False
    output = []

    for line in content.splitlines():
        try:
            source = parse_one_line_style_line(line)
        except ValueError:
            # Malformed line, keep it as is
            output.append(line)
            continue

        # Blank lines, comments and non-Kali sources are kept as is
        if not source or source.suite not in KALI_SUITES:
            output.append(line)
            continue

        if source.suite in remove_suites:
            modified = True
            continue

        uri = source.uri

        if protocol and protocol != uri.scheme:
            uri = uri._replace(scheme=protocol)

        # A new mirror comes with the standard "kali" path; the other
        # parts of the original URI (params, query, fragment) are dropped
        if mirror and mirror != uri.netloc:
            uri = uri._replace(
                netloc=mirror, path="kali", params="", query="", fragment=""
            )

        # Only re-print the lines that actually changed, so that the
        # other ones keep their original formatting
        if uri != source.uri:
            source = source._replace(uri=uri)
            line = print_one_line_style_line(source)
            modified = True

        output.append(line)

    if modified:
        content = "\n".join(output) + "\n" if output else ""

    return content, modified


# Deb822-Style Sources --------------------------------------------------------

# A source as found in a stanza of a deb822-style .sources file. As built
# by mk_deb822_source(): 'types', 'suites' and 'components' are lists of
# strings, 'uris' is a list of urllib.parse.ParseResult, and 'options' is
# a dict holding the remaining fields (lower-cased keys).
Deb822Source = namedtuple(
    "Deb822Source", ["types", "uris", "suites", "components", "options"]
)


def mk_deb822_source(types, uris, suites, components, options):
    """
    Build a Deb822Source from the raw field values of a deb822 stanza.

    'types', 'uris', 'suites' and 'components' are whitespace-separated
    strings that are split into lists (each URI parsed with urlparse()),
    and 'options' is copied into a new dict.

    Raise ValueError if uri parsing fails.
    """

    return Deb822Source(
        types=types.split(),
        uris=list(map(urlparse, uris.split())),
        suites=suites.split(),
        components=components.split(),
        # Copy, so that the caller's dict is never shared with the source
        options=dict(options),
    )


def parse_deb822_style_stanza(stanza):
    """
    Parse a deb822 stanza.

    If it's an enabled Kali source, return a named tuple. If it is not a
    Kali source, or if it's disabled ("Enabled: no", case-insensitive),
    return None. If the stanza is malformed, or mandatory fields are
    missing, raise ValueError.

    Only the Kali suites are kept in the returned source. The fields other
    than Types, URIs, Suites and Components are returned as options, with
    lower-cased keys; multi-line values are joined with newlines.

    For reference regarding the format of sources.list file:
    https://manpages.debian.org/unstable/apt/sources.list.5.en.html
    """

    # Parse the stanza
    data = {}
    key = None
    for line in stanza.splitlines():
        # Ignore comments
        if line.startswith("#"):
            continue
        # Handle multiline values: in deb822, continuation lines start
        # with a space or a tab
        if line.startswith((" ", "\t")):
            if not key:
                raise ValueError
            if key in data:
                data[key] += "\n" + line.strip()
            else:
                data[key] = line.strip()
        # Handle key value pair
        else:
            key, sep, value = line.partition(":")
            if not sep:
                raise ValueError
            key = key.lower()
            value = value.strip()
            if value:
                data[key] = value

    # Ignore empty stanzas
    if not data:
        return None

    # Ignore stanzas that are not enabled ("No" is honored as well as "no")
    if data.get("enabled", "yes").lower() == "no":
        return None

    # Mandatory fields that we're interested in
    try:
        types = data.pop("types")
        uris = data.pop("uris")
        suites = data.pop("suites")
        components = data.pop("components")
    except KeyError:
        raise ValueError from None

    # Filter out non-Kali suites, return None if
    # there's no Kali suite left after that
    kali_suites = [s for s in suites.split() if s in KALI_SUITES]
    if not kali_suites:
        return None
    suites = " ".join(kali_suites)

    return mk_deb822_source(types, uris, suites, components, data)


def print_deb822_style_stanza(source):
    """
    Format a source as a deb822-style stanza, and return a multiline
    string terminated by a newline.

    The mandatory fields come first, then the options in their original
    order (keys title-cased, multi-line values written as indented
    continuation lines), then a Signed-By field pointing at KALI_KEYRING
    if the source had none.

    This multiline is meant to be added to a .sources file.
    """

    def format_field(name, text):
        label = name.title()
        text = text.rstrip("\n")
        if "\n" not in text:
            return f"{label}: {text}"
        return f"{label}:\n" + indent(text, " ")

    mandatory = [
        ("Types", " ".join(source.types)),
        ("URIs", " ".join(urlunparse(u) for u in source.uris)),
        ("Suites", " ".join(source.suites)),
        ("Components", " ".join(source.components)),
    ]
    lines = [f"{label}: {text}" for label, text in mandatory]

    lines.extend(format_field(k, v) for k, v in source.options.items())

    if "signed-by" not in source.options:
        lines.append(format_field("signed-by", KALI_KEYRING))

    return "\n".join(lines) + "\n"


def print_deb822_style_default_stanza(proto, mirror, extra_suites):
    """
    Print the default Kali source in the deb822-style format,
    return a multiline string.

    The stanza enables kali-rolling followed by the suites listed in
    'extra_suites', which can be any iterable of suite names (list,
    tuple, ...), with all of KALI_COMPONENTS and KALI_KEYRING as
    Signed-By.

    This multiline is meant to be added to a .sources file.
    """

    # Unpacking (rather than list concatenation) accepts any iterable,
    # not just lists
    suites = " ".join(["kali-rolling", *extra_suites])
    components = " ".join(KALI_COMPONENTS)

    lines = [
        "Types: deb",
        f"URIs: {proto}://{mirror}/kali/",
        f"Suites: {suites}",
        f"Components: {components}",
        f"Signed-By: {KALI_KEYRING}",
    ]

    return "\n".join(lines) + "\n"


def parse_deb822_style_sources(content):
    """
    Parse the content of a deb822 sources file, and return the Kali
    sources found in it, as a list of Deb822Source.

    Stanzas are separated by blank lines; malformed stanzas, disabled
    stanzas and non-Kali stanzas are skipped.
    """

    sources = []
    # The regex splits on blank (or whitespace-only) lines; it may yield
    # empty blocks and leave trailing newlines, which are dealt with below
    for block in re.split(r"\r?\n\s*\r?\n", content):
        if not block:
            continue
        try:
            source = parse_deb822_style_stanza(block.rstrip("\n"))
        except ValueError:
            continue
        if source:
            sources.append(source)

    return sources


def analyze_deb822_style_sources(sources):
    """
    Analyze a list of deb822-style sources, and return the tuple
    (protocol, mirror, extra_suites) computed by consolidate_repo_settings()
    over all the URIs and suites of all the sources.
    """

    all_uris = [uri for source in sources for uri in source.uris]
    all_suites = [suite for source in sources for suite in source.suites]

    return consolidate_repo_settings(
        [uri.scheme for uri in all_uris],
        [uri.netloc for uri in all_uris],
        all_suites,
    )


def update_deb822_style_sources(
    content, protocol=None, mirror=None, extra_suites=None
):
    """
    Parse the content from a .sources file, and modify it according to
    the arguments. If we find a stanza that matches Kali, modify it and
    return only this stanza. If we don't find a Kali stanza, return
    the default Kali stanza.

    Params:
    - content: content of a deb822-style .sources file
    - protocol: new URI scheme, or None to keep the current one
    - mirror: new mirror hostname, or None to keep the current one; when
      the mirror changes, the URI path is reset to "/kali/"
    - extra_suites: iterable of the extra suites that must be enabled;
      the suites of KALI_EXTRA_SUITES that are not listed are removed.
      None (the default) means no extra suite.

    Return a tuple with a modified content, and a boolean that says whether
    the content was modified or not.
    """

    # Avoid a shared mutable default, and work on a private list
    extra_suites = list(extra_suites) if extra_suites is not None else []

    modified = False

    # Parse
    sources = parse_deb822_style_sources(content)

    # Pick the stanza that matches a Kali suite,
    # in order of preference.
    source = None
    for suite in ("kali-rolling", "kali-last-snapshot", "kali-dev"):
        source = next((s for s in sources if suite in s.suites), None)
        if source is not None:
            break

    # If no source, create one from scratch
    if source is None:
        protocol = protocol or KALI_DEFAULT_PROTOCOL
        mirror = mirror or KALI_DEFAULT_MIRROR
        content = print_deb822_style_default_stanza(protocol, mirror, extra_suites)
        return content, True

    # Edit the URIs (the source was freshly parsed, so its lists can be
    # modified in place)
    for idx, uri in enumerate(source.uris):
        orig_uri = uri

        if protocol and protocol != uri.scheme:
            uri = uri._replace(scheme=protocol)

        if mirror and mirror != uri.netloc:
            uri = uri._replace(
                netloc=mirror, path="kali/", params="", query="", fragment=""
            )

        if uri != orig_uri:
            source.uris[idx] = uri
            modified = True

    # Remove extra suites that are not enabled
    for suite in KALI_EXTRA_SUITES:
        if suite not in extra_suites and suite in source.suites:
            source.suites.remove(suite)
            modified = True

    # Append extra suites that are enabled and missing
    for suite in extra_suites:
        if suite not in source.suites:
            source.suites.append(suite)
            modified = True

    # Write it out
    if modified:
        content = print_deb822_style_stanza(source)

    return content, modified


# Classes ---------------------------------------------------------------------


class TempBackup:
    """
    Class to backup files in the temporary directory, then either restore
    those backups all at once, or delete them all.

    Each entry of self.backups is a tuple (original_path, backup_path).
    backup_path is None when the original file didn't exist at backup
    time: restoring it then means removing the file.
    """

    def __init__(self):
        self.backups = []

    def backup(self, file):
        """
        Create a backup of a file as a temporary file.

        A file that doesn't exist is tracked anyway, so that restore_all()
        removes it if it gets created in the meantime.
        """
        if not os.path.exists(file):
            self.backups += [(file, None)]
            logger.debug("Backed up %s (non-existing file)", file)
            return
        # mkstemp() creates a unique file; close it right away and let
        # copy2() fill it with the content (and metadata) of the original
        fd, temp_path = tempfile.mkstemp()
        os.close(fd)
        shutil.copy2(file, temp_path)
        self.backups += [(file, temp_path)]
        logger.debug("Backed up %s -> %s", file, temp_path)

    def restore_all(self):
        """
        Restore all backups, then forget about them.

        When the current user lacks the permission, the operation is
        retried as root. If it still fails, a message is printed, and the
        backup file is kept so that the user can restore it manually.
        """
        for orig, bck in self.backups:
            if bck is None:
                # The file didn't exist originally: remove it
                try:
                    os.remove(orig)
                except FileNotFoundError:
                    # It was never created, nothing to undo. Don't let
                    # this abort the restore of the other files.
                    pass
                except PermissionError:
                    cmd = f"rm -f {orig}"
                    res = run_as_root(cmd, interactive=True, log=False)
                    if res.returncode != 0:
                        print(f"Couldn't remove file '{orig}'!")
                        print("Please remove it manually.")
            else:
                try:
                    shutil.copy2(bck, orig)
                except PermissionError:
                    cmd = f"cp --preserve=all {bck} {orig}"
                    res = run_as_root(cmd, interactive=True, log=False)
                    if res.returncode != 0:
                        print(f"Couldn't restore file '{orig}'!")
                        print(f"Please restore it manually from: '{bck}'.")
                        # Keep the backup file for the user
                        continue
                os.remove(bck)
        self.backups = []

    def remove_all(self):
        """
        Remove all backup files, then forget about them.

        Backup files that are already gone are ignored: this is cleanup
        after a successful operation, it must not fail.
        """
        for _, bck in self.backups:
            if bck is None:
                continue
            try:
                os.remove(bck)
            except FileNotFoundError:
                pass
        self.backups = []


class AptRepositoriesSetting:
    """
    Load and save the configuration of the Kali APT repositories: the
    protocol, the mirror, and which extra suites are enabled.

    Both deb822-style (.sources) and one-line-style (.list) files are
    read. When Kali deb822-style sources are found, their settings take
    precedence, and save() writes the Kali configuration to kali.sources.
    """

    def __init__(self, update_apt=True):
        # Whether save() runs 'apt update' after modifying the files
        self.update_apt = update_apt
        self._reset_state()

    def _reset_state(self):
        """
        Forget the configuration found by a previous _load().
        """
        self.mirror = None
        self.protocol = None
        self.extra_suites = {}
        self.has_deb822_sources = False

    def _get_deb822_source_files(self):
        """
        Return the deb822-style files that might hold Kali sources.
        """
        source_files = [
            "/etc/apt/sources.list.d/kali.sources",
            "/etc/apt/sources.list.d/moved-from-main.sources",
        ]
        source_files += [
            f"/etc/apt/sources.list.d/{suite}.sources" for suite in KALI_EXTRA_SUITES
        ]
        return source_files

    def _get_one_line_style_sources(self):
        """
        Return the one-line-style files: /etc/apt/sources.list, then the
        *.list files of /etc/apt/sources.list.d, sorted so that they are
        always processed in the same order.
        """
        source_files = ["/etc/apt/sources.list"]
        source_files += sorted(glob.glob("/etc/apt/sources.list.d/*.list"))
        return source_files

    @staticmethod
    def _read_file(fn):
        """
        Return the content of a file, or None if it's missing or unreadable.
        """
        try:
            with open(fn) as f:
                return f.read()
        except (FileNotFoundError, PermissionError):
            return None

    def _load(self):
        """
        Load the various Kali APT sources files, and find out how Kali
        repositories are configured. Settings found in deb822-style
        sources take precedence over one-line-style ones.
        """
        # Start from a clean slate, so that nothing is left over from a
        # previous call if the files changed in the meantime
        self._reset_state()

        # Parse deb822-style sources
        deb822_sources = []
        for fn in self._get_deb822_source_files():
            content = self._read_file(fn)
            if content is not None:
                deb822_sources += parse_deb822_style_sources(content)

        # Parse one-line-style sources
        one_line_sources = []
        for fn in self._get_one_line_style_sources():
            content = self._read_file(fn)
            if content is not None:
                one_line_sources += parse_one_line_style_sources(content)

        # Analyze and consolidate the results, give precedence to deb822
        if one_line_sources:
            config = analyze_one_line_style_sources(one_line_sources)
            self.protocol, self.mirror, self.extra_suites = config
        if deb822_sources:
            config = analyze_deb822_style_sources(deb822_sources)
            self.protocol, self.mirror, self.extra_suites = config
            self.has_deb822_sources = True

    def load(self):
        """
        Load the various sources.list files.

        Return a dict showing how Kali is configured, to the best
        of our knowledge.
        """
        self._load()

        return {
            "extra-repos": dict(self.extra_suites),
            "mirror": self.mirror,
            "protocol": self.protocol,
        }

    def save(self, config):
        """
        Save changes.

        The input parameter 'config' is a dictionary, and it contains
        only the settings that have been changed by user. Its optional
        keys are "protocol", "mirror", and "extra-repos" (a dict mapping
        a suite name to a boolean).

        Raise RuntimeError if writing the files or running apt update
        fails; in that case the original files are restored first.
        """

        self._load()

        # Turn input into variables that we're going to use
        new_proto = config.get("protocol", None)
        new_mirror = config.get("mirror", None)
        extra_repos = config.get("extra-repos", {})

        # A special case: if caller wants to set the protocol to HTTPS,
        # and the mirror is neither http.kali.org nor kali.download,
        # then we force the mirror to a default value. This is because
        # we can't just set HTTPS for an unknown mirror, as we have no
        # idea if it supports HTTPS.
        if new_proto == "https" and not new_mirror and not self.mirror:
            new_mirror = KALI_DEFAULT_MIRROR

        # We're going to backup every file that we touch
        backups = TempBackup()

        if self.has_deb822_sources:
            changed = self._save_deb822_style(
                backups, new_proto, new_mirror, extra_repos
            )
        else:
            changed = self._save_one_line_style(
                backups, new_proto, new_mirror, extra_repos
            )

        # No file was modified?
        if not changed:
            return

        # Run apt update
        if self.update_apt:
            try:
                apt_update(force=True)
            except RuntimeError:
                print("> Error, restoring original APT sources files")
                backups.restore_all()
                raise

        # Everything went fine, remove backup files
        backups.remove_all()

    def _save_deb822_style(self, backups, new_proto, new_mirror, extra_repos):
        """
        Update deb822-style sources - everything goes to kali.sources,
        and other .sources and .list files that are known to belong to
        Kali are removed.

        Every file touched is backed up in 'backups'. Return False if there
        was nothing to change, True otherwise. On RuntimeError, restore the
        backups and re-raise.
        """

        # We have a list of extra suites for which the value (on/off)
        # was changed by user. We want to turn that into the definitive
        # list of extra suites that are enabled, based on user changes,
        # but also based on what's currently enabled and didn't change.
        extra_suites_enabled = []
        for suite in KALI_EXTRA_SUITES:
            if suite in extra_repos:
                enabled = extra_repos[suite]
            else:
                enabled = self.extra_suites.get(suite, False)
            if enabled is True:
                extra_suites_enabled.append(suite)

        # Read the file of interest: the first one that can be read
        content = None
        for fn in [
            "/etc/apt/sources.list.d/kali.sources",
            "/etc/apt/sources.list.d/moved-from-main.sources",
        ]:
            content = self._read_file(fn)
            if content is not None:
                break

        # Update or create
        if content:
            content, modified = update_deb822_style_sources(
                content,
                protocol=new_proto,
                mirror=new_mirror,
                extra_suites=extra_suites_enabled,
            )
        else:
            content = print_deb822_style_default_stanza(
                new_proto or self.protocol or KALI_DEFAULT_PROTOCOL,
                new_mirror or self.mirror or KALI_DEFAULT_MIRROR,
                extra_suites_enabled,
            )
            modified = True

        # No modification?
        if not modified:
            return False

        # Write changes to kali.sources
        fn = "/etc/apt/sources.list.d/kali.sources"
        backups.backup(fn)
        try:
            print(f"> Writing changes to {fn}")
            write_file_as_root(fn, content)
        except RuntimeError:
            print("> Error, restoring original APT sources files")
            backups.restore_all()
            raise

        # Remove other files if ever they exist.
        # NOTE(review): /etc/apt/sources.list and moved-from-main.sources
        # are removed as a whole, while only the Kali stanza was carried
        # over to kali.sources - any non-Kali entry they hold is lost.
        # Confirm this is intended.
        source_files = [
            f"/etc/apt/sources.list.d/{suite}.sources" for suite in KALI_EXTRA_SUITES
        ]
        source_files.append("/etc/apt/sources.list.d/moved-from-main.sources")
        source_files.append("/etc/apt/sources.list")
        source_files += [
            f"/etc/apt/sources.list.d/{suite}.list" for suite in KALI_EXTRA_SUITES
        ]
        for fn in source_files:
            if not os.path.exists(fn):
                continue
            backups.backup(fn)
            try:
                print(f"> Removing file {fn}")
                run_as_root(f"rm {fn}", log=False)
            except RuntimeError:
                print("> Error, restoring original APT sources files")
                backups.restore_all()
                raise

        return True

    def _save_one_line_style(self, backups, new_proto, new_mirror, extra_repos):
        """
        Update one-line-style sources: change the mirror and protocol in
        place, remove the lines of the disabled extra suites, and add each
        enabled extra suite in /etc/apt/sources.list.d/<suite>.list.

        Every file touched (including the new ones) is backed up in
        'backups'. Return False if no file was changed, True otherwise.
        On RuntimeError, restore the backups and re-raise.
        """

        # Transform extra_repos into two lists
        suites_to_add = [s for s, enabled in extra_repos.items() if enabled is True]
        suites_to_remove = [
            s for s, enabled in extra_repos.items() if enabled is not True
        ]

        # Apply the updates, which means:
        # - update mirror and protocol if requested
        # - remove repositories if requested
        for fn in self._get_one_line_style_sources():
            content = self._read_file(fn)
            if content is None:
                continue

            content, modified = update_one_line_style_sources(
                content,
                protocol=new_proto,
                mirror=new_mirror,
                remove_suites=suites_to_remove,
            )
            if not modified:
                continue

            # Write changes, or remove the file if only whitespace is left
            backups.backup(fn)
            try:
                if content.strip():
                    print(f"> Writing changes to {fn}")
                    write_file_as_root(fn, content)
                else:
                    print(f"> Removing file {fn}")
                    run_as_root(f"rm {fn}", log=False)
            except RuntimeError:
                print("> Error, restoring original sources.list files")
                backups.restore_all()
                raise

        # Add repositories
        for suite in suites_to_add:
            line = print_one_line_style_default_line(
                new_proto or self.protocol or KALI_DEFAULT_PROTOCOL,
                new_mirror or self.mirror or KALI_DEFAULT_MIRROR,
                suite,
            )
            target = f"/etc/apt/sources.list.d/{suite}.list"
            print(f"> New repository: {line}")
            print(f"> Installing to: '{target}'")
            # Track the target, so that it's removed (or its previous
            # content restored) if something fails later on. If it was
            # already backed up above, keep that earlier backup: it holds
            # the true original state.
            if target not in {orig for orig, _ in backups.backups}:
                backups.backup(target)
            try:
                # Terminate the file with a newline, like the other
                # sources files that we write
                write_file_as_root(target, line + "\n")
            except RuntimeError:
                print("> Error, restoring original sources.list files")
                backups.restore_all()
                raise

        return bool(backups.backups or suites_to_add)
