// Copyright 2012 Google Inc.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
//   notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
//   notice, this list of conditions and the following disclaimer in the
//   documentation and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors
//   may be used to endorse or promote products derived from this software
//   without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "fs.h"

#if defined(HAVE_CONFIG_H)
#   include "config.h"
#endif

#if defined(HAVE_UNMOUNT)
#   include <sys/param.h>
#   include <sys/mount.h>
#endif
#include <sys/stat.h>
#include <sys/wait.h>

#include <assert.h>
#include <dirent.h>
#include <err.h>
#include <errno.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "defs.h"
#include "error.h"


/// Specifies if a real unmount(2) is available.
///
/// We use this as a constant instead of a macro so that we can compile both
/// versions of the unmount code unconditionally.  This is a way to prevent
/// compilation bugs going unnoticed for long.
static const bool have_unmount2 =
#if defined(HAVE_UNMOUNT)
    true;
#else
    false;
#endif


#if !defined(UMOUNT)
/// Fake replacement value to the path to umount(8).
#   define UMOUNT "do-not-use-this-value"
#else
#   if defined(HAVE_UNMOUNT)
#       error "umount(8) detected when unmount(2) is also available"
#   endif
#endif


#if !defined(HAVE_UNMOUNT)
/// Fake unmount(2) function for systems without it.
///
/// This is only provided to allow our code to compile in all platforms
/// regardless of whether they actually have an unmount(2) or not.
///
/// \param unused_path The mount point to be unmounted.
/// \param unused_flags The flags to the unmount(2) call.
///
/// \return -1 to indicate error, although this should never happen.
static int
unmount(const char* KYUA_DEFS_UNUSED_PARAM(path),
        const int KYUA_DEFS_UNUSED_PARAM(flags))
{
    assert(false);
    return -1;
}
#endif


/// Scans a directory and executes a callback on each entry.
///
/// \param directory The directory to scan.
/// \param callback The function to execute on each entry.
/// \param argument A cookie to pass to the callback function.
///
/// \return True if the directory scan and the calls to the callback function
/// are all successful; false otherwise.
///
/// \note Errors are logged to stderr and do not stop the algorithm.
static bool
try_iterate_directory(const char* directory,
                      bool (*callback)(const char*, const void*),
                      const void* argument)
{
    bool ok = true;

    DIR* dirp = opendir(directory);
    if (dirp == NULL) {
        warn("opendir(%s) failed", directory);
        ok &= false;
    } else {
        struct dirent* dp;
        while ((dp = readdir(dirp)) != NULL) {
            const char* name = dp->d_name;
            if (strcmp(name, ".") == 0 || strcmp(name, "..") == 0)
                continue;

            char* subdir;
            const kyua_error_t error = kyua_fs_concat(&subdir, directory, name,
                                                      NULL);
            if (kyua_error_is_set(error)) {
                kyua_error_free(error);
                warn("path concatenation failed");
                ok &= false;
            } else {
                ok &= callback(subdir, argument);
                free(subdir);
            }
        }
        closedir(dirp);
    }

    return ok;
}


/// Stats a file, without following links.
///
/// \param path The file to stat.
/// \param [out] sb Pointer to the stat structure in which to place the result.
///
/// \return The stat structure on success; none on failure.
///
/// \note Errors are logged to stderr.
static bool
try_stat(const char* path, struct stat* sb)
{
    if (lstat(path, sb) == -1) {
        warn("lstat(%s) failed", path);
        return false;
    } else
        return true;
}


/// Removes a directory.
///
/// \param path The directory to remove.
///
/// \return True on success; false otherwise.
///
/// \note Errors are logged to stderr.
static bool
try_rmdir(const char* path)
{
    if (rmdir(path) == -1) {
        warn("rmdir(%s) failed", path);
        return false;
    } else
        return true;
}


/// Removes a file.
///
/// \param path The file to remove.
///
/// \return True on success; false otherwise.
///
/// \note Errors are logged to stderr.
static bool
try_unlink(const char* path)
{
    if (unlink(path) == -1) {
        warn("unlink(%s) failed", path);
        return false;
    } else
        return true;
}


/// Unmounts a mount point.
///
/// \param path The location to unmount.
///
/// \return True on success; false otherwise.
///
/// \note Errors are logged to stderr.
static bool
try_unmount(const char* path)
{
    const kyua_error_t error = kyua_fs_unmount(path);
    if (kyua_error_is_set(error)) {
        kyua_error_warn(error, "Cannot unmount %s", path);
        kyua_error_free(error);
        return false;
    } else
        return true;
}


/// Attempts to weaken the permissions of a file.
///
/// \param path The file to unprotect.
///
/// \return True on success; false otherwise.
///
/// \note Errors are logged to stderr.
static bool
try_unprotect(const char* path)
{
    static const mode_t new_mode = 0700;

    if (chmod(path, new_mode) == -1) {
        warnx("chmod(%s, %04o) failed", path, new_mode);
        return false;
    } else
        return true;
}


/// Attempts to weaken the permissions of a symbolic link.
///
/// \param path The symbolic link to unprotect.
///
/// \return True on success; false otherwise.
///
/// \note Errors are logged to stderr.
static bool
try_unprotect_symlink(const char* path)
{
    static const mode_t new_mode = 0700;

#if HAVE_WORKING_LCHMOD
    if (lchmod(path, new_mode) == -1) {
        warnx("lchmod(%s, %04o) failed", path, new_mode);
        return false;
    } else
        return true;
#else
    warnx("lchmod(%s, %04o) failed; system call not implemented", path,
          new_mode);
    return false;
#endif
}


/// Traverses a hierarchy unmounting any mount points in it.
///
/// \param current_path The file or directory to traverse.
/// \param raw_parent_sb The stat structure of the enclosing directory.
///
/// \return True on success; false otherwise.
///
/// \note Errors are logged to stderr and do not stop the algorithm.
static bool
recursive_unmount(const char* current_path, const void* raw_parent_sb)
{
    const struct stat* parent_sb = raw_parent_sb;

    struct stat current_sb;
    bool ok = try_stat(current_path, &current_sb);
    if (ok) {
        if (S_ISDIR(current_sb.st_mode)) {
            assert(!S_ISLNK(current_sb.st_mode));
            ok &= try_iterate_directory(current_path, recursive_unmount,
                                        &current_sb);
        }

        if (current_sb.st_dev != parent_sb->st_dev)
            ok &= try_unmount(current_path);
    }

    return ok;
}


/// Traverses a hierarchy and removes all of its contents.
///
/// This honors mount points: when a mount point is encountered, it is traversed
/// in search for other mount points, but no files within any of these are
/// removed.
///
/// \param current_path The file or directory to traverse.
/// \param raw_parent_sb The stat structure of the enclosing directory.
///
/// \return True on success; false otherwise.
///
/// \note Errors are logged to stderr and do not stop the algorithm.
static bool
recursive_cleanup(const char* current_path, const void* raw_parent_sb)
{
    const struct stat* parent_sb = raw_parent_sb;

    struct stat current_sb;
    bool ok = try_stat(current_path, &current_sb);
    if (ok) {
        // Weakening the protections of a file is just a best-effort operation.
        // If this fails, we may still be able to do the file/directory removal
        // later on, so ignore any failures from try_unprotect().
        //
        // One particular case in which this fails is if try_unprotect() is run
        // on a symbolic link that points to a file for which the unprotect is
        // not possible, and lchmod(3) is not available.
        if (S_ISLNK(current_sb.st_mode))
            try_unprotect_symlink(current_path);
        else
            try_unprotect(current_path);

        if (current_sb.st_dev != parent_sb->st_dev) {
            ok &= recursive_unmount(current_path, parent_sb);
            if (ok)
                ok &= recursive_cleanup(current_path, parent_sb);
        } else {
            if (S_ISDIR(current_sb.st_mode)) {
                assert(!S_ISLNK(current_sb.st_mode));
                ok &= try_iterate_directory(current_path, recursive_cleanup,
                                            &current_sb);
                ok &= try_rmdir(current_path);
            } else {
                ok &= try_unlink(current_path);
            }
        }
    }

    return ok;
}


/// Unmounts a file system using unmount(2).
///
/// \pre unmount(2) must be available; i.e. have_unmount2 must be true.
///
/// \param mount_point The file system to unmount.
///
/// \return An error object.
static kyua_error_t
unmount_with_unmount2(const char* mount_point)
{
    assert(have_unmount2);

    if (unmount(mount_point, 0) == -1) {
        return kyua_libc_error_new(errno, "unmount(%s) failed",
                                   mount_point);
    }

    return kyua_error_ok();
}


/// Unmounts a file system using umount(8).
///
/// \pre umount(2) must not be available; i.e. have_unmount2 must be false.
///
/// \param mount_point The file system to unmount.
///
/// \return An error object.
static kyua_error_t
unmount_with_umount8(const char* mount_point)
{
    assert(!have_unmount2);

    const pid_t pid = fork();
    if (pid == -1) {
        return kyua_libc_error_new(errno, "fork() failed");
    } else if (pid == 0) {
        const int ret = execlp(UMOUNT, "umount", mount_point, NULL);
        assert(ret == -1);
        err(EXIT_FAILURE, "Failed to execute " UMOUNT);
    }

    kyua_error_t error = kyua_error_ok();
    int status;
    if (waitpid(pid, &status, 0) == -1) {
        error = kyua_libc_error_new(errno, "waitpid(%d) failed", pid);
    } else {
        if (WIFEXITED(status)) {
            if (WEXITSTATUS(status) == EXIT_SUCCESS)
                assert(!kyua_error_is_set(error));
            else {
                error = kyua_libc_error_new(EBUSY, "unmount(%s) failed",
                                            mount_point);
            }
        } else
            error = kyua_libc_error_new(EFAULT, "umount(8) crashed");
    }
    return error;
}


/// Recursively removes a directory.
///
/// \param root The directory or file to remove.  Cannot be a mount point.
///
/// \return An error object.
kyua_error_t
kyua_fs_cleanup(const char* root)
{
    struct stat current_sb;
    bool ok = try_stat(root, &current_sb);
    if (ok)
        ok &= recursive_cleanup(root, &current_sb);

    if (!ok) {
        warnx("Cleanup of '%s' failed", root);
        return kyua_libc_error_new(EPERM, "Cleanup of %s failed", root);
    } else
        return kyua_error_ok();
}


/// Concatenates a set of strings to form a path.
///
/// \param [out] output Pointer to a dynamically-allocated string that will hold
///     the resulting path, if all goes well.
/// \param first First component of the path to concatenate.
/// \param ... All other components to concatenate.
///
/// \return An error if there is not enough memory to fulfill the request; OK
/// otherwise.
kyua_error_t
kyua_fs_concat(char** const output, const char* first, ...)
{
    va_list ap;
    const char* component;

    va_start(ap, first);
    size_t length = strlen(first) + 1;
    while ((component = va_arg(ap, const char*)) != NULL) {
        length += 1 + strlen(component);
    }
    va_end(ap);

    *output = (char*)malloc(length);
    if (output == NULL)
        return kyua_oom_error_new();
    char* iterator = *output;

    int added_size;
    added_size = snprintf(iterator, length, "%s", first);
    iterator += added_size; length -= added_size;

    va_start(ap, first);
    while ((component = va_arg(ap, const char*)) != NULL) {
        added_size = snprintf(iterator, length, "/%s", component);
        iterator += added_size; length -= added_size;
    }
    va_end(ap);

    return kyua_error_ok();
}


/// Queries the path to the current directory.
///
/// \param [out] out_cwd Dynamically-allocated pointer to a string holding the
///     current path.  The caller must use free() to release it.
///
/// \return An error object.
kyua_error_t
kyua_fs_current_path(char** out_cwd)
{
    char* cwd;
#if defined(HAVE_GETCWD_DYN)
    cwd = getcwd(NULL, 0);
#else
    {
        const char* static_cwd = ::getcwd(NULL, MAXPATHLEN);
        const kyua_error_t error = kyua_fs_concat(&cwd, static_cwd, NULL);
        if (kyua_error_is_set(error))
            return error;
    }
#endif
    if (cwd == NULL) {
        return kyua_libc_error_new(errno, "getcwd() failed");
    } else {
        *out_cwd = cwd;
        return kyua_error_ok();
    }
}


/// Converts a path to absolute.
///
/// \param original The path to convert; may already be absolute.
/// \param [out] output Pointer to a dynamically-allocated string that will hold
///     the absolute path, if all goes well.
///
/// \return An error if there is not enough memory to fulfill the request; OK
/// otherwise.
kyua_error_t
kyua_fs_make_absolute(const char* original, char** const output)
{
    if (original[0] == '/') {
        *output = (char*)malloc(strlen(original) + 1);
        if (output == NULL)
            return kyua_oom_error_new();
        strcpy(*output, original);
        return kyua_error_ok();
    } else {
        char* current_path;
        kyua_error_t error;

        error = kyua_fs_current_path(&current_path);
        if (kyua_error_is_set(error))
            return error;

        error = kyua_fs_concat(output, current_path, original, NULL);
        free(current_path);
        return error;
    }
}


/// Unmounts a file system.
///
/// \param mount_point The file system to unmount.
///
/// \return An error object.
kyua_error_t
kyua_fs_unmount(const char* mount_point)
{
    kyua_error_t error;

    // FreeBSD's unmount(2) requires paths to be absolute.  To err on the side
    // of caution, let's make it absolute in all cases.
    char* abs_mount_point;
    error = kyua_fs_make_absolute(mount_point, &abs_mount_point);
    if (kyua_error_is_set(error))
        goto out;

    static const int unmount_retries = 3;
    static const int unmount_retry_delay_seconds = 1;

    int retries = unmount_retries;
retry:
    if (have_unmount2) {
        error = unmount_with_unmount2(abs_mount_point);
    } else {
        error = unmount_with_umount8(abs_mount_point);
    }
    if (kyua_error_is_set(error)) {
        assert(kyua_error_is_type(error, "libc"));
        if (kyua_libc_error_errno(error) == EBUSY && retries > 0) {
            kyua_error_warn(error, "%s busy; unmount retries left %d",
                            abs_mount_point, retries);
            kyua_error_free(error);
            retries--;
            sleep(unmount_retry_delay_seconds);
            goto retry;
        }
    }

out:
    free(abs_mount_point);
    return error;
}