// SPDX-License-Identifier: BSD-3-Clause
// rpi-modcopy - Selectively copy kernel modules and their dependencies

#include "modcopy.hpp"

#include <cerrno>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <queue>
#include <system_error>
#include <vector>

#include <libkmod.h>
#include <spawn.h>
#include <sys/wait.h>

namespace fs = std::filesystem;

namespace modcopy {

namespace {

// Deleter that lets std::unique_ptr own a kmod_ctx: the context is released
// with kmod_unref() when the owning pointer goes away.
struct KmodCtxDeleter {
    void operator()(kmod_ctx* ctx) const {
        if (ctx == nullptr) {
            return;  // nothing to release
        }
        kmod_unref(ctx);
    }
};
// Owning handle for the context created by kmod_new().
using KmodCtxPtr = std::unique_ptr<kmod_ctx, KmodCtxDeleter>;

// Deleter that lets std::unique_ptr own a kmod_list: the list is released
// with kmod_module_unref_list() when the owning pointer goes away.
struct KmodListDeleter {
    void operator()(kmod_list* list) const {
        if (list == nullptr) {
            return;  // nothing to release
        }
        kmod_module_unref_list(list);
    }
};
// Owning handle for lists such as the one filled by kmod_module_new_from_lookup().
using KmodListPtr = std::unique_ptr<kmod_list, KmodListDeleter>;

// Write dst_modules/modules.order containing only those entries of
// src_modules/modules.order whose module file is in files_to_copy, keeping
// the original order.
//
// modules.order uses uncompressed relative paths such as
// "kernel/crypto/xts.ko", while files_to_copy holds source paths that may
// carry a .xz/.zst/.gz suffix. Copied module paths are therefore made
// relative to src_modules and stripped of that suffix before matching.
//
// Returns true on success (also when the source has no modules.order, in
// which case nothing is written). Returns false when either file cannot be
// opened or when a read or write error occurs.
bool write_filtered_modules_order(const fs::path& src_modules, const fs::path& dst_modules,
                                  const std::set<fs::path>& files_to_copy, bool verbose) {
    const fs::path src_order = src_modules / "modules.order";
    const fs::path dst_order = dst_modules / "modules.order";

    if (!fs::exists(src_order)) {
        // No modules.order in source - that's fine
        return true;
    }

    std::ifstream in(src_order);
    if (!in) {
        std::cerr << "error: cannot read " << src_order << "\n";
        return false;
    }

    std::ofstream out(dst_order);
    if (!out) {
        std::cerr << "error: cannot write " << dst_order << "\n";
        return false;
    }

    // Relative, compression-stripped paths of the module files being copied.
    std::set<std::string> copied_modules;
    for (const auto& f : files_to_copy) {
        // Only module files ("foo.ko" or compressed "foo.ko.xz"), not metadata.
        // Only the filename is inspected, so a ".ko." inside a directory name
        // cannot make a non-module file match.
        const std::string filename = f.filename().string();
        if (f.extension() != ".ko" && filename.find(".ko.") == std::string::npos) {
            continue;
        }
        std::string rel_str = fs::relative(f, src_modules).string();
        for (const char* ext : {".xz", ".zst", ".gz"}) {
            if (rel_str.ends_with(ext)) {
                rel_str.resize(rel_str.size() - std::strlen(ext));
                break;
            }
        }
        copied_modules.insert(rel_str);
    }

    // Both ends of a line are trimmed with the same set. A line made only of
    // whitespace (including a lone '\r' from CRLF input) is then skipped, and
    // otherwise end >= start, so the length below never wraps.
    constexpr const char* whitespace = " \t\r\n";

    std::string line;
    int kept = 0;
    int total = 0;
    while (std::getline(in, line)) {
        total++;
        const size_t start = line.find_first_not_of(whitespace);
        if (start == std::string::npos) {
            continue;
        }
        const size_t end = line.find_last_not_of(whitespace);
        const std::string module_path = line.substr(start, end - start + 1);

        if (copied_modules.contains(module_path)) {
            out << module_path << "\n";
            kept++;
        }
    }

    // getline() ends at EOF by setting failbit; badbit signals a real read error.
    if (in.bad()) {
        std::cerr << "error: failed reading " << src_order << "\n";
        return false;
    }

    // Report write failures (e.g. a full filesystem) instead of returning
    // success with a truncated file. close() flushes and sets failbit on error.
    out.close();
    if (!out) {
        std::cerr << "error: failed writing " << dst_order << "\n";
        return false;
    }

    if (verbose) {
        std::cerr << "Filtered modules.order: " << kept << "/" << total << " entries\n";
    }

    return true;
}

}  // namespace

// Resolve opts.modules (module names or aliases) plus the dependencies that
// libkmod reports for them into the set of files to copy
// (Result::files_to_copy).
//
// Lookups use libkmod against the indexes under
// opts.source / opts.module_dir / opts.kernel_version, with an empty
// configuration list so that system modprobe configs are not loaded.
// Builtin modules (those without a file path) are reported in verbose mode
// but add no file. The metadata files modules.builtin and
// modules.builtin.modinfo are added when present. modules.order is not
// collected here; it is filtered and written separately.
//
// Result::exit_code:
//   1 - at least one module failed to resolve and opts.keep_going was set
//       (each failure is counted in modules_skipped)
//   3 - setup failed (missing directory, kmod context or index load), or a
//       module failed to resolve without opts.keep_going
//   otherwise it is left as initialised by Result.
// Result::modules_copied is the number of collected files, metadata included.
// NOLINTBEGIN(readability-function-cognitive-complexity) ; complex due to kmod API requirements
Result resolve_modules(const Options& opts) {
    Result result;

    // Construct the full path to modules directory
    // e.g., /mnt/kernel-pkg/usr/lib/modules/6.6.20+rpt-rpi-v8
    const fs::path modules_base =
        opts.source / opts.module_dir.relative_path() / opts.kernel_version;

    if (!fs::exists(modules_base)) {
        std::cerr << "error: module directory not found: " << modules_base << "\n";
        result.exit_code = 3;
        return result;
    }

    if (opts.verbose) {
        std::cerr << "Using module directory: " << modules_base << "\n";
    }

    // Create a kmod context pointing at the full modules directory (including
    // the kernel version). The config list holds only its terminating nullptr
    // so that no system modprobe configuration is loaded.
    // NOLINTNEXTLINE(modernize-avoid-c-arrays) ; kmod API requires C array
    const char* empty_config[] = {nullptr};
    const KmodCtxPtr ctx{kmod_new(modules_base.c_str(), empty_config)};

    if (ctx == nullptr) {
        std::cerr << "error: failed to create kmod context\n";
        result.exit_code = 3;
        return result;
    }

    // Load indexes for faster lookups
    if (kmod_load_resources(ctx.get()) < 0) {
        std::cerr << "error: failed to load kmod resources\n";
        result.exit_code = 3;
        return result;
    }

    // Names already handled: both the spelling that was queued and the
    // normalised name kmod reports for each resolved module.
    std::set<std::string> processed;

    // Breadth-first work queue of names/aliases still to resolve.
    std::queue<std::string> to_process;
    for (const auto& mod_name : opts.modules) {
        to_process.push(mod_name);
    }

    while (!to_process.empty()) {
        const std::string mod_name = to_process.front();
        to_process.pop();

        if (processed.contains(mod_name)) {
            continue;
        }
        // Record the queued spelling too. A name or alias that appears more
        // than once is then looked up, and any failure counted, only once.
        processed.insert(mod_name);

        if (opts.verbose) {
            std::cerr << "Resolving: " << mod_name << "\n";
        }

        // Look up the module (handles names and aliases)
        kmod_list* list_raw = nullptr;
        const int err = kmod_module_new_from_lookup(ctx.get(), mod_name.c_str(), &list_raw);
        const KmodListPtr list{list_raw};

        if (err < 0) {
            // NOLINTNEXTLINE(concurrency-mt-unsafe) ; acceptable for error messages
            const char* const reason = std::strerror(-err);
            std::cerr << "error: failed to lookup module '" << mod_name << "': " << reason
                      << "\n";
            if (opts.keep_going) {
                result.modules_skipped++;
                result.exit_code = 1;
                continue;
            }
            result.exit_code = 3;
            return result;
        }

        if (list == nullptr) {
            std::cerr << "warning: module not found: " << mod_name << "\n";
            if (opts.keep_going) {
                result.modules_skipped++;
                result.exit_code = 1;
                continue;
            }
            result.exit_code = 3;
            return result;
        }

        // Process each module in the list (aliases may resolve to multiple)
        // NOLINTBEGIN(cppcoreguidelines-init-variables) ; kmod_list_foreach macro initialises itr
        kmod_list* itr;
        kmod_list_foreach(itr, list.get()) {
            kmod_module* mod = kmod_module_get_module(itr);
            if (mod == nullptr) {
                continue;
            }

            const char* name = kmod_module_get_name(mod);
            const char* path = kmod_module_get_path(mod);
            // Streaming a null const char* is undefined behaviour, so fall
            // back to the queued name for every diagnostic below.
            const char* const display_name = (name != nullptr) ? name : mod_name.c_str();

            // Mark as processed using normalised name
            if (name != nullptr) {
                processed.insert(name);
            }

            // Check if builtin (path is NULL for builtins)
            if (path == nullptr) {
                if (opts.verbose) {
                    std::cerr << "  " << display_name << " (builtin)\n";
                }
                kmod_module_unref(mod);
                continue;
            }

            if (opts.verbose) {
                std::cerr << "  " << display_name << " -> " << path << "\n";
            }

            // Add the module file to our set
            result.files_to_copy.insert(path);

            // Queue every dependency not yet handled
            kmod_list* deps = kmod_module_get_dependencies(mod);
            if (deps != nullptr) {
                kmod_list* dep_itr;
                kmod_list_foreach(dep_itr, deps) {
                    kmod_module* dep_mod = kmod_module_get_module(dep_itr);
                    if (dep_mod != nullptr) {
                        const char* dep_name = kmod_module_get_name(dep_mod);
                        if (dep_name != nullptr && !processed.contains(dep_name)) {
                            to_process.emplace(dep_name);
                        }
                        kmod_module_unref(dep_mod);
                    }
                }
                kmod_module_unref_list(deps);
            }

            kmod_module_unref(mod);
        }
        // NOLINTEND(cppcoreguidelines-init-variables)
    }

    // Always include essential metadata files (modules.order is handled separately)
    // NOLINTNEXTLINE(readability-identifier-naming) ; local const vector, not a global constant
    const std::vector<std::string> metadata_files = {
        "modules.builtin",
        "modules.builtin.modinfo",
    };

    for (const auto& meta : metadata_files) {
        const fs::path meta_path = modules_base / meta;
        if (fs::exists(meta_path)) {
            result.files_to_copy.insert(meta_path);
        }
    }

    result.modules_copied = static_cast<int>(result.files_to_copy.size());

    if (opts.verbose) {
        std::cerr << "\nResolved " << result.files_to_copy.size() << " files to copy\n";
    }

    return result;
}
// NOLINTEND(readability-function-cognitive-complexity)

// Copy the resolved files from the source tree into the destination tree,
// preserving their layout below <module_dir>/<kernel_version>. Afterwards,
// write a modules.order filtered to the modules that were copied.
//
// In dry-run mode nothing is written. The destination directory and each
// file's path relative to the source module directory are printed to stdout.
//
// Returns a copy of `resolved`, with exit_code set to 4 on any filesystem
// error. This includes a resolved file that does not lie under the source
// module directory: its relative path would start with "..", and joining it
// onto the destination would write outside the destination tree.
Result copy_modules(const Options& opts, const Result& resolved) {
    Result result = resolved;

    if (resolved.files_to_copy.empty()) {
        return result;
    }

    // Source and destination module directories
    const fs::path src_modules =
        opts.source / opts.module_dir.relative_path() / opts.kernel_version;
    const fs::path dst_modules = opts.dest / opts.module_dir.relative_path() / opts.kernel_version;

    if (opts.dry_run) {
        std::cout << "Would copy to: " << dst_modules << "\n";
        for (const auto& file : resolved.files_to_copy) {
            const fs::path rel = fs::relative(file, src_modules);
            std::cout << rel.string() << "\n";
        }
        std::cout << "modules.order (filtered to copied modules only)\n";
        return result;
    }

    // Create destination directory structure
    std::error_code ec;
    fs::create_directories(dst_modules, ec);
    if (ec) {
        std::cerr << "error: failed to create destination directory: " << ec.message() << "\n";
        result.exit_code = 4;
        return result;
    }

    // Copy each file, preserving directory structure
    for (const auto& src_file : resolved.files_to_copy) {
        const fs::path rel = fs::relative(src_file, src_modules, ec);
        if (ec) {
            std::cerr << "error: cannot compute path of " << src_file << " relative to "
                      << src_modules << ": " << ec.message() << "\n";
            result.exit_code = 4;
            return result;
        }
        // A file outside src_modules (e.g. if module metadata referenced an
        // absolute or "../" path) must never be joined onto dst_modules.
        if (rel.empty() || *rel.begin() == "..") {
            std::cerr << "error: " << src_file << " is not inside " << src_modules << "\n";
            result.exit_code = 4;
            return result;
        }
        const fs::path dst_file = dst_modules / rel;

        // Create parent directories
        fs::create_directories(dst_file.parent_path(), ec);
        if (ec) {
            std::cerr << "error: failed to create directory " << dst_file.parent_path() << ": "
                      << ec.message() << "\n";
            result.exit_code = 4;
            return result;
        }

        if (opts.verbose) {
            std::cerr << "Copying: " << rel.string() << "\n";
        }

        fs::copy_file(src_file, dst_file, fs::copy_options::overwrite_existing, ec);
        if (ec) {
            std::cerr << "error: failed to copy " << src_file << ": " << ec.message() << "\n";
            result.exit_code = 4;
            return result;
        }
    }

    // Write filtered modules.order (only includes modules we actually copied)
    if (!write_filtered_modules_order(src_modules, dst_modules, resolved.files_to_copy,
                                      opts.verbose)) {
        result.exit_code = 4;
        return result;
    }

    return result;
}

int run_depmod(const Options& opts) {
    // Build argument list for depmod
    // -b basedir: where to look for modules and write output
    std::vector<std::string> args = {"depmod", "-b", opts.dest.string()};

    // If module_dir is not the default, specify it
    if (opts.module_dir != "/lib/modules") {
        args.emplace_back("-m");
        args.emplace_back(opts.module_dir.string());
    }

    args.push_back(opts.kernel_version);

    if (opts.dry_run || opts.verbose) {
        std::cerr << (opts.dry_run ? "Would run:" : "Running:");
        for (const auto& arg : args) {
            std::cerr << " " << arg;
        }
        std::cerr << "\n";
        if (opts.dry_run) {
            return 0;
        }
    }

    // Convert to C-style argv array for posix_spawn.
    // posix_spawn takes char *const argv[] - the pointers are const but the
    // strings themselves must be mutable, hence args is non-const above.
    std::vector<char*> argv;
    argv.reserve(args.size() + 1);
    // cppcheck-suppress constVariableReference
    for (auto& arg : args) {  // posix_spawn may modify strings (char *const argv[])
        // cppcheck-suppress useStlAlgorithm
        argv.push_back(arg.data());  // std::transform awkward with mutable string requirement
    }
    argv.push_back(nullptr);

    // Spawn depmod process
    pid_t pid = 0;
    const int err = posix_spawn(&pid, "/sbin/depmod", nullptr, nullptr, argv.data(), environ);
    if (err != 0) {
        // NOLINTNEXTLINE(concurrency-mt-unsafe) ; acceptable for error messages
        std::cerr << "error: failed to spawn depmod: " << std::strerror(err) << "\n";
        return 4;
    }

    // Wait for completion (retry on EINTR)
    int status = 0;
    pid_t ret = 0;
    // NOLINTBEGIN(cppcoreguidelines-avoid-do-while) ; standard EINTR retry pattern
    do {
        ret = waitpid(pid, &status, 0);
    } while (ret == -1 && errno == EINTR);
    // NOLINTEND(cppcoreguidelines-avoid-do-while)

    if (ret == -1) {
        // NOLINTNEXTLINE(concurrency-mt-unsafe) ; acceptable for error messages
        std::cerr << "error: waitpid failed: " << std::strerror(errno) << "\n";
        return 4;
    }

    if (WIFEXITED(status)) {
        const int exit_code = WEXITSTATUS(status);
        if (exit_code != 0) {
            std::cerr << "error: depmod failed with exit code " << exit_code << "\n";
            return 4;
        }
    } else if (WIFSIGNALED(status)) {
        std::cerr << "error: depmod killed by signal " << WTERMSIG(status) << "\n";
        return 4;
    }

    return 0;
}

}  // namespace modcopy
