#!/usr/bin/python3
"""
    Copyright (C) 2021 Michael Ablassmeier <abi@grinser.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import os
import sys
import tempfile
import logging
import argparse
from argparse import Namespace
from typing import List, Union, Dict
import pprint
import json
from libvirtnbdbackup import argopt
from libvirtnbdbackup import __version__
from libvirtnbdbackup import nbdcli
from libvirtnbdbackup import virt
from libvirtnbdbackup import qemu
from libvirtnbdbackup import output
from libvirtnbdbackup.common import common as lib
from libvirtnbdbackup.logcount import logCount
from libvirtnbdbackup.sparsestream import streamer
from libvirtnbdbackup.sparsestream import types
from libvirtnbdbackup.exceptions import RestoreError, UntilCheckpointReached
from libvirtnbdbackup.sparsestream import exceptions
from libvirtnbdbackup.qemu.exceptions import ProcessError
from libvirtnbdbackup.output.exceptions import OutputException
from libvirtnbdbackup.virt.client import DomainDisk


def dump(args: Namespace, stream, dataFiles: List[str]) -> bool:
    """Dump stream contents to json output

    Collects the metadata header of every file in `dataFiles` and prints
    the collected headers as one JSON list to stdout.

    If `args.disk` is set, files whose basename does not start with that
    value are skipped. If `args.sequence` is set, the entries in
    `dataFiles` are treated as names relative to `args.input`.

    Files whose header cannot be read are reported via the error log and
    left out of the output; they do not abort the dump.

    Always returns True.
    """
    logging.info("Dumping saveset meta information")
    entries = []
    for dataFile in dataFiles:
        if args.disk is not None and not os.path.basename(dataFile).startswith(
            args.disk
        ):
            continue
        logging.info(dataFile)

        sourceFile = dataFile
        if args.sequence:
            sourceFile = f"{args.input}/{dataFile}"

        try:
            meta = getHeader(sourceFile, stream)
        except RestoreError as errmsg:
            # Do not drop unreadable files silently: the user should know
            # why a file is missing from the dump.
            logging.error(errmsg)
            continue

        entries.append(meta)

        if lib.isCompressed(meta):
            logging.info("Compressed stream found: [%s].", meta["compressionMethod"])

    print(json.dumps(entries, indent=4))

    return True


def restoreData(
    args: Namespace, stream, dataFile: str, targetFile: str, connection
) -> bool:
    """Restore data for disk

    Reads the sparse stream stored in `dataFile` and writes its data
    frames through `connection` (the NBD connection to the target image).
    `targetFile` is only used for log output.

    Processing steps:
      * read the metadata frame, and the compression trailer if the
        stream is compressed,
      * return early if the metadata reports a data size of 0,
      * walk the remaining frames: zero frames are only logged, data
        frames are written (chunked if the uncompressed frame size
        reaches `connection.maxRequestSize`), the stop frame ends the
        loop after the processed size was compared to the metadata.

    The backup file is closed on every exit path.

    Returns True on success.

    Raises RestoreError if the file cannot be opened, the stream cannot
    be parsed, a frame terminator is missing, writing fails or the
    processed size does not match the metadata.
    Raises UntilCheckpointReached after a successful restore if the
    checkpoint name of the file equals `args.until`.
    """
    sTypes = types.SparseStreamTypes()

    try:
        reader = open(dataFile, "rb")
    except OSError as errmsg:
        logging.critical("Failed to open backup file for reading: [%s].", errmsg)
        raise RestoreError from errmsg

    # The context manager closes the backup file no matter how the
    # function is left (early return, error, reached checkpoint).
    with reader:
        try:
            kind, start, length = stream.readFrame(reader)
            meta = stream.loadMetadata(reader.read(length))
        except exceptions.StreamFormatException as errmsg:
            logging.fatal(errmsg)
            raise RestoreError from errmsg

        # The metadata does not change while the file is processed, so
        # the compression state is determined once.
        compressed = lib.isCompressed(meta)

        trailer = None
        if compressed is True:
            trailer = stream.readCompressionTrailer(reader)
            logging.info("Found compression trailer.")
            logging.debug("%s", trailer)

        if meta["dataSize"] == 0:
            logging.info("File [%s] contains no dirty blocks, skipping.", dataFile)
            return True

        logging.info(
            "Applying data from backup file [%s] to target file [%s].",
            dataFile,
            targetFile,
        )
        pprint.pprint(meta)
        _expectTerm(reader, sTypes.TERM)

        progressBar = lib.progressBar(
            meta["dataSize"], f"restoring disk [{meta['diskName']}]", args
        )
        dataSize = 0
        dataBlockCnt = 0
        while True:
            try:
                kind, start, length = stream.readFrame(reader)
            except exceptions.StreamFormatException as err:
                logging.error(
                    "Cant read stream at pos: [%s]: [%s]", reader.tell(), err
                )
                raise RestoreError from err
            if kind == sTypes.ZERO:
                logging.debug("Zero segment from [%s] length: [%s]", start, length)
            elif kind == sTypes.DATA:
                logging.debug(
                    "Processing data segment from [%s] length: [%s]", start, length
                )

                # Size accounting uses the size from the frame header,
                # the amount read from the file is the size listed in
                # the compression trailer if there is one.
                originalSize = length
                if trailer:
                    logging.debug("Block: [%s]", dataBlockCnt)
                    logging.debug("Original block size: [%s]", length)
                    length = trailer[dataBlockCnt]
                    logging.debug("Compressed block size: [%s]", length)

                if originalSize >= connection.maxRequestSize:
                    logging.debug(
                        "Chunked read/write, start: [%s], len: [%s]", start, length
                    )
                    try:
                        written = lib.readChunk(
                            reader,
                            start,
                            length,
                            connection,
                            compressed,
                        )
                    except Exception as e:
                        logging.exception(e)
                        raise RestoreError from e
                    logging.debug("Wrote: [%s]", written)
                else:
                    try:
                        data = reader.read(length)
                        if compressed:
                            data = lib.lz4DecompressFrame(data)
                        connection.nbd.pwrite(data, start)
                        written = len(data)
                    except Exception as e:
                        logging.exception(e)
                        raise RestoreError from e

                _expectTerm(reader, sTypes.TERM)
                dataSize += originalSize
                progressBar.update(written)
                dataBlockCnt += 1
            elif kind == sTypes.STOP:
                progressBar.close()
                if dataSize != meta["dataSize"]:
                    logging.error(
                        "Restored data size does not match [%s] != [%s]",
                        dataSize,
                        meta["dataSize"],
                    )
                    raise RestoreError("Data size mismatch")
                break

    logging.info("End of stream, [%s] bytes of data processed", dataSize)
    if meta["checkpointName"] == args.until:
        logging.info("Reached checkpoint [%s], stopping", args.until)
        raise UntilCheckpointReached

    return True


def _expectTerm(reader, term: bytes) -> None:
    """Consume the frame terminator, raise RestoreError if it is missing.

    The read must not be part of an `assert` statement: asserts are
    removed with `python -O`, which would also remove the read and leave
    the stream positioned in front of the terminator.
    """
    found = reader.read(len(term))
    if found != term:
        logging.error(
            "Unexpected frame terminator at pos: [%s]: [%s]", reader.tell(), found
        )
        raise RestoreError("Invalid frame terminator in backup file.")


def restoreSequence(
    args: Namespace, dataFiles: List[str], virtClient: virt.client
) -> bool:
    """Reconstruct image from a given set of data files

    The entries of `dataFiles` are file names relative to `args.input`.
    The header of the first file defines name, format and size of the
    target image, which is created below `args.output`. All files are
    then applied in the given order through one NBD connection.

    Processing stops without error as soon as the checkpoint passed via
    `args.until` has been applied. The NBD connection is disconnected
    even if applying a file fails.

    Returns the result of the last processed file, or False if the
    header of the first file is empty.

    Raises RestoreError if no files were passed, the target file already
    exists, the target image cannot be created, the NBD server cannot be
    started or applying the data fails.
    """
    if not dataFiles:
        raise RestoreError("No data files specified for restore.")

    stream = streamer.SparseStream(types)

    result: bool = False

    sourceFile = f"{args.input}/{dataFiles[0]}"
    meta = getHeader(sourceFile, stream)
    if not meta:
        return result

    diskName = meta["diskName"]
    targetFile = f"{args.output}/{diskName}"
    if lib.exists(args, targetFile):
        raise RestoreError(f"Targetfile {targetFile} already exists.")

    # A RestoreError raised here is passed on to the caller unchanged.
    createDiskFile(args, meta, targetFile, args.sshClient)

    try:
        connection = startNbd(args, diskName, targetFile, virtClient)
    except ProcessError as errmsg:
        logging.error(errmsg)
        raise RestoreError("Failed to start NBD server.") from errmsg

    try:
        for disk in dataFiles:
            sourceFile = f"{args.input}/{disk}"
            try:
                result = writeData(args, stream, sourceFile, targetFile, connection)
            except UntilCheckpointReached:
                # The requested checkpoint has been applied, remaining
                # files are left out on purpose.
                result = True
                break
    finally:
        connection.disconnect()

    return result


def writeData(args: Namespace, stream, disk: str, targetFile: str, connection) -> bool:
    """Apply the data stream of one backup file to the target file

    Passes all arguments on to restoreData and returns its result. A
    missing result (None) is reported as True.
    """
    outcome = restoreData(args, stream, disk, targetFile, connection)
    # None means no data has been processed, which counts as success
    return True if outcome is None else outcome


def getQcowConfig(args: Namespace, meta: Dict[str, str]) -> List[str]:
    """Check if backup includes exported qcow config and return a list
    of options passed to qemu-img create command

    Looks in `args.input` for files matching
    "<diskName>*.qcow.json*" and loads the first entry returned by
    lib.getLatest as JSON.

    The returned list consists of complete "-o", "<name>=<value>" pairs
    for the settings found in the config (compat, cluster_size and, if
    enabled, lazy_refcounts). A setting missing from the config is
    logged and left out. The list is empty if no config file exists or
    the config cannot be loaded.
    """
    opt: List[str] = []
    qcowConfig = None
    qcowConfigFile = lib.getLatest(args.input, f"{meta['diskName']}*.qcow.json*", -1)
    if not qcowConfigFile:
        logging.info("No QCOW image config found, will use default options.")
        return opt

    lastConfigFile = qcowConfigFile[0]

    try:
        with output.openfile(lastConfigFile, "rb") as qFh:
            qcowConfig = json.loads(qFh.read().decode())
        logging.info("Using QCOW options from backup file: [%s]", lastConfigFile)
    except (
        OutputException,
        json.decoder.JSONDecodeError,
    ) as errmsg:
        logging.warning(
            "Unable to load original QCOW image config, using defaults: [%s].",
            errmsg,
        )
        return opt

    # Look up each value before touching the option list: appending "-o"
    # first would leave a dangling "-o" behind if the key is missing.
    try:
        compat = qcowConfig["format-specific"]["data"]["compat"]
    except KeyError as errmsg:
        logging.warning("Unable apply QCOW specific compat option: [%s]", errmsg)
    else:
        opt.extend(["-o", f"compat={compat}"])

    try:
        clusterSize = qcowConfig["cluster-size"]
    except KeyError as errmsg:
        logging.warning("Unable apply QCOW specific cluster_size option: [%s]", errmsg)
    else:
        opt.extend(["-o", f"cluster_size={clusterSize}"])

    try:
        lazyRefcounts = qcowConfig["format-specific"]["data"]["lazy-refcounts"]
    except KeyError as errmsg:
        logging.warning(
            "Unable apply QCOW specific lazy_refcounts option: [%s]", errmsg
        )
    else:
        if lazyRefcounts:
            opt.extend(["-o", "lazy_refcounts=on"])

    return opt


def createDiskFile(
    args: Namespace, meta: Dict[str, str], targetFile: str, sshClient
) -> None:
    """Create target image file

    Creates `targetFile` with the format and virtual size recorded in
    the backup header `meta`, using the image options returned by
    getQcowConfig. `sshClient` is passed on to the create call.

    Raises RestoreError if the target file already exists or the image
    cannot be created. The exception carries a message, so callers
    logging it do not report an empty reason.
    """
    logging.info(
        "Create virtual disk [%s] format: [%s] size: [%s]",
        targetFile,
        meta["diskFormat"],
        meta["virtualSize"],
    )

    options = getQcowConfig(args, meta)
    if lib.exists(args, targetFile):
        logging.error("Target file already exists: [%s], wont overwrite.", targetFile)
        raise RestoreError(f"Target file already exists: [{targetFile}]")

    qFh = qemu.util(meta["diskName"])
    try:
        qFh.create(
            targetFile, int(meta["virtualSize"]), meta["diskFormat"], options, sshClient
        )
    except ProcessError as e:
        logging.error("Cant create restore target: [%s]", e)
        raise RestoreError(f"Cant create restore target: [{e}]") from e


def getHeader(diskFile: str, stream) -> Dict[str, str]:
    """Return the metadata header stored in the given data file

    Stream format errors and output errors are both converted into a
    RestoreError which names the affected file and the reason.
    """
    try:
        header = lib.dumpMetaData(diskFile, stream)
    except exceptions.StreamFormatException as errmsg:
        reason = f"Reading metadata from [{diskFile}] failed: [{errmsg}]"
        raise RestoreError(reason) from errmsg
    except OutputException as errmsg:
        reason = f"Reading data file [{diskFile}] failed: [{errmsg}]"
        raise RestoreError(reason) from errmsg

    return header


def startNbd(
    args: Namespace, exportName: str, targetFile: str, virtClient: virt.client
):
    """Start the NBD server used for the restore and connect to it

    With a remote host configured on `virtClient` the server is started
    remotely and reached via TCP, using `args.nbd_ip` as address if it
    is not empty. Otherwise a local server is started and reached via
    the unix socket `args.socketfile`.

    Returns the result of the client's connect() call.
    """
    qemuUtil = qemu.util(exportName)
    target: Union[nbdcli.TCP, nbdcli.Unix]
    if virtClient.remoteHost:
        nbdHost = args.nbd_ip if args.nbd_ip != "" else virtClient.remoteHost
        logging.info(
            "Starting remote NBD server on socket: [%s:%s], port: [%s]",
            nbdHost,
            args.socketfile,
            args.nbd_port,
        )
        server = qemuUtil.startRemoteRestoreNbdServer(args, targetFile)
        target = nbdcli.TCP(exportName, "", nbdHost, args.tls, args.nbd_port)
    else:
        logging.info("Starting local NBD server on socket: [%s]", args.socketfile)
        server = qemuUtil.startRestoreNbdServer(targetFile, args.socketfile)
        target = nbdcli.Unix(exportName, "", args.socketfile)

    client = nbdcli.client(target)
    logging.info("Started NBD server, PID: [%s]", server.pid)
    return client.connect()


def readConfig(ConfigFile: str) -> str:
    """Read saved virtual machine config

    Returns the decoded contents of `ConfigFile`. The file is closed
    again after reading.

    Any error while opening, reading or decoding the file is logged and
    re-raised unchanged.
    """
    try:
        with output.openfile(ConfigFile, "rb") as fh:
            return fh.read().decode()
    except Exception:
        # Only regular errors are reported as read failure; interrupts
        # and interpreter exits pass through without this message.
        logging.error("Cant read config file: [%s]", ConfigFile)
        raise


def getDisksFromConfig(
    args: Namespace, vmConfig: str, virtClient: virt.client
) -> List[DomainDisk]:
    """Return the disks described by the given domain configuration

    The parsing itself is done by `virtClient.getDomainDisks`, this
    function only hands over the arguments and the config.
    """
    domainDisks = virtClient.getDomainDisks(args, vmConfig)
    return domainDisks


def checkBackingStore(args: Namespace, disk: DomainDisk) -> None:
    """Warn about restored disks which have backing store images

    If the disk lists backing stores and the configuration is not
    adjusted during restore (`args.adjust_config`), the user is told
    that the virtual machine configuration has to be changed manually.
    """
    # Nothing to report without backing stores or with config adjustment
    if len(disk.backingstores) == 0 or args.adjust_config:
        return

    logging.warning("Target image [%s] seems to be a snapshot image.", disk.filename)
    for notice in (
        "Target virtual machine configuration must be altered!",
        "Configured backing store images must be changed.",
    ):
        logging.warning(notice)


def restoreFiles(args: Namespace, vmConfig: str, virtClient: virt.client) -> None:
    """Restore additional files referenced by the domain config

    `virtClient.getDomainInfo` returns a mapping of boot option name to
    file path for the saved config (loader / nvram according to the
    original description of this function). For each entry:

      * nothing is done if the file already exists at its path,
      * otherwise the matching file from the backup directory
        `args.input` is copied to that path,
      * a warning is logged if the backup contains no matching file.
    """
    config = readConfig(vmConfig)
    info = virtClient.getDomainInfo(config)

    for setting, val in info.items():
        f = lib.getLatest(args.input, f"*{os.path.basename(val)}*", -1)
        if lib.exists(args, val):
            logging.info(
                "File [%s]: for boot option [%s] already exists, skipping.",
                val,
                setting,
            )
            continue

        if not f:
            # Without this check the lookup of the first match below
            # fails with an IndexError.
            logging.warning(
                "No backup of file [%s] for boot option [%s] found, skipping.",
                val,
                setting,
            )
            continue

        logging.info(
            "Restoring configured file [%s] for boot option [%s]", val, setting
        )
        lib.copy(args, f[0], val)


def setTargetFile(args: Namespace, disk: DomainDisk) -> str:
    """Return the path of the file to restore the disk to

    The file is placed in `args.output` and named like the original
    disk file, or like the disk target if no file name is set.
    """
    name = disk.target if disk.filename is None else disk.filename
    return f"{args.output}/{name}"


def restore(args: Namespace, ConfigFile: str, virtClient: virt.client) -> bytes:
    """Handle restore operation

    Restores every disk listed in the saved domain config `ConfigFile`,
    or only the disk whose target matches `args.disk` if that is set.

    Per disk:
      * disks without backup files are skipped (and removed from the
        config if `args.adjust_config` is set),
      * raw disks are copied as is if `args.raw` is set,
      * otherwise the target image is created, an NBD server is started
        for it and the backup files are applied in order until the
        checkpoint given by `args.until` is reached or a file fails.

    The NBD connection of a disk is always disconnected, also if an
    error occurs while the disk is processed.

    Returns the domain configuration as bytes: the adjusted one if
    `args.adjust_config` is set, else the original one. The value is
    empty if no disk went through the regular restore path.

    Raises RestoreError if no disks can be parsed, no full or copy
    backup is found, the target image cannot be created or the NBD
    server cannot be started.
    """
    stream = streamer.SparseStream(types)
    vmConfig = readConfig(ConfigFile)
    vmDisks = getDisksFromConfig(args, vmConfig, virtClient)
    if not vmDisks:
        raise RestoreError("Unable to parse disks from config")

    restConfig: bytes = b""
    for disk in vmDisks:
        if args.disk not in (None, disk.target):
            logging.info("Skipping disk [%s] for restore", disk.target)
            continue

        restoreDisk = lib.getLatest(args.input, f"{disk.target}*.data")
        logging.debug("Restoring disk: [%s]", restoreDisk)
        if len(restoreDisk) < 1:
            logging.warning(
                "No backup file for disk [%s] found, assuming it has been excluded.",
                disk.target,
            )
            if args.adjust_config is True:
                restConfig = virtClient.adjustDomainConfigRemoveDisk(
                    vmConfig, disk.target
                )
            continue

        targetFile = setTargetFile(args, disk)

        if args.raw and disk.format == "raw":
            logging.info("Restoring raw image to [%s]", targetFile)
            lib.copy(args, restoreDisk[0], targetFile)
            continue

        if "full" not in restoreDisk[0] and "copy" not in restoreDisk[0]:
            logging.error(
                "[%s]: Unable to locate base full or copy backup.", restoreDisk[0]
            )
            raise RestoreError("Failed to locate backup.")

        meta = getHeader(restoreDisk[0], stream)

        try:
            createDiskFile(args, meta, targetFile, args.sshClient)
        except RestoreError as errmsg:
            raise RestoreError("Creating target image failed.") from errmsg

        try:
            connection = startNbd(args, meta["diskName"], targetFile, virtClient)
        except ProcessError as errmsg:
            logging.error(errmsg)
            raise RestoreError("Failed to start NBD server.") from errmsg

        try:
            for dataFile in restoreDisk:
                try:
                    writeData(args, stream, dataFile, targetFile, connection)
                except UntilCheckpointReached:
                    break
                except RestoreError as errmsg:
                    # Processing of this disk stops here and the restore
                    # carries on as before, but the failure is reported
                    # instead of being dropped silently.
                    # NOTE(review): consider passing the error on to the
                    # caller, the target image is incomplete at this point.
                    logging.error(
                        "Restoring [%s] failed, disk [%s] is incomplete: [%s]",
                        dataFile,
                        disk.target,
                        errmsg,
                    )
                    break

            checkBackingStore(args, disk)
            if args.adjust_config is True:
                # NOTE(review): each call starts from the original vmConfig,
                # so with multiple disks only the last adjustment seems to
                # end up in restConfig. Confirm against adjustDomainConfig.
                restConfig = virtClient.adjustDomainConfig(
                    args, disk, vmConfig, targetFile
                )
            else:
                restConfig = vmConfig.encode()
        finally:
            connection.disconnect()

    return restConfig


def restoreConfig(args: Namespace, vmConfig: str, adjustedConfig: bytes) -> None:
    """Restore either original or adjusted vm configuration
    to new directory

    The config is stored in `args.output` under the base name of
    `vmConfig`.

    With `args.adjust_config` set, `adjustedConfig` is written: directly
    to the target file, or, if an ssh client is set, to a temporary file
    which is then copied to the target. Otherwise the original config
    file is copied and the user is told to adjust it manually.
    """
    targetFile = f"{args.output}/{os.path.basename(vmConfig)}"
    if args.adjust_config is True:
        if args.sshClient:
            with tempfile.NamedTemporaryFile(delete=True) as fh:
                fh.write(adjustedConfig)
                # The copy is handed the file name, not the handle, so
                # buffered data must be flushed first or the copied
                # config may be empty or truncated.
                fh.flush()
                lib.copy(args, fh.name, targetFile)
        else:
            with output.openfile(targetFile, "wb") as cnf:
                cnf.write(adjustedConfig)
            logging.info("Adjusted config placed in: [%s]", targetFile)
        if args.define is False:
            logging.info("Use 'virsh define %s' to define VM", targetFile)
    else:
        lib.copy(args, vmConfig, targetFile)
        logging.info("Copied original vm config to [%s]", targetFile)
        logging.info("Note: virtual machine config must be adjusted manually.")


def main() -> None:
    """Entry point: parse the command line and run the requested action

    Either dumps the metadata of the backup files or restores the disks,
    the additional boot files and the domain configuration. Exits with
    status 1 on errors and with status 0 after a dump.
    """
    parser = _createParser()

    args = lib.argparse(parser)
    args.sshClient = None
    # default values for common usage of lib.getDomainDisks
    args.exclude = None
    args.include = args.disk
    stream = streamer.SparseStream(types)
    fileLog = lib.getLogFile(args.logfile)
    if not fileLog:
        sys.exit(1)
    counter = logCount()

    lib.configLogger(args, fileLog, counter)
    lib.printVersion(__version__)

    if not lib.exists(args, args.input):
        logging.error("Backup source [%s] does not exist.", args.input)
        sys.exit(1)

    dataFiles: List[str]
    if args.sequence is None:
        dataFiles = lib.getLatest(args.input, "*.data")
        if not dataFiles:
            logging.error("No data files found in directory: [%s]", args.input)
            sys.exit(1)
    else:
        logging.info("Using manual specified sequence of files.")
        logging.info("Disabling redefine and config adjust options.")
        args.define = False
        args.adjust_config = False
        dataFiles = args.sequence.split(",")

        firstFile = dataFiles[0]
        if "full" not in firstFile and "copy" not in firstFile:
            logging.error("Sequence must start with full or copy backup.")
            sys.exit(1)

    # "dump" is accepted both as action and as output target
    if "dump" in (args.action, args.output):
        dump(args, stream, dataFiles)
        sys.exit(0)

    if args.action != "restore":
        return

    # Defining the domain requires the adjusted configuration
    if args.define is True:
        args.adjust_config = True

    virtClient = virt.client(args)
    if not virtClient.remoteHost:
        output.target.Directory(args.output)
    else:
        args.sshClient = lib.sshSession(args, virtClient.remoteHost)
        if not args.sshClient:
            logging.error("Remote restore detected but ssh session setup failed")
            sys.exit(1)
        args.sshClient.sftp.mkdir(args.output)

    ConfigFiles = lib.getLatest(args.input, "vmconfig*.xml", -1)
    if not ConfigFiles:
        logging.error("No domain config file found")
        sys.exit(1)

    ConfigFile = ConfigFiles[0]
    logging.info("Using latest config file: [%s]", ConfigFile)

    autoStart = bool(lib.getLatest(args.input, "autostart.*", -1))

    restConfig: bytes = b""
    try:
        if args.sequence is None:
            restConfig = restore(args, ConfigFile, virtClient)
        else:
            restoreSequence(args, dataFiles, virtClient)
    except RestoreError as errmsg:
        logging.fatal("Disk restore failed: [%s]", errmsg)
        sys.exit(1)

    restoreFiles(args, ConfigFile, virtClient)
    restoreConfig(args, ConfigFile, restConfig)
    virtClient.refreshPool(args.output)
    if args.define is True and not virtClient.defineDomain(restConfig, autoStart):
        sys.exit(1)


def _createParser() -> argparse.ArgumentParser:
    """Build the command line parser with all option groups"""
    examples = (
        "Examples:\n"
        "   # Dump backup metadata:\n"
        "\t%(prog)s -i /backup/ -o dump\n"
        "   # Complete restore with all disks:\n"
        "\t%(prog)s -i /backup/ -o /target\n"
        "   # Complete restore, adjust config and redefine vm after restore:\n"
        "\t%(prog)s -cD -i /backup/ -o /target\n"
        "   # Complete restore, adjust config and redefine vm with name 'foo':\n"
        "\t%(prog)s -cD --name foo -i /backup/ -o /target\n"
        "   # Restore only disk 'vda':\n"
        "\t%(prog)s -i /backup/ -o /target -d vda\n"
        "   # Point in time restore:\n"
        "\t%(prog)s -i /backup/ -o /target --until virtnbdbackup.2\n"
        "   # Restore and process specific file sequence:\n"
        "\t%(prog)s -i /backup/ -o /target "
        "--sequence vdb.full.data,vdb.inc.virtnbdbackup.1.data\n"
        "   # Restore to remote system:\n"
        "\t%(prog)s -U qemu+ssh://root@remotehost/system"
        " --ssh-user root -i /backup/ -o /remote_target"
    )
    parser = argparse.ArgumentParser(
        description="Restore virtual machine disks",
        epilog=examples,
        formatter_class=argparse.RawTextHelpFormatter,
    )

    general = parser.add_argument_group("General options")
    general.add_argument(
        "-a",
        "--action",
        required=False,
        type=str,
        choices=["dump", "restore"],
        default="restore",
        help="Action to perform: (default: %(default)s)",
    )
    general.add_argument(
        "-i",
        "--input",
        required=True,
        type=str,
        help="Directory including a backup set",
    )
    general.add_argument(
        "-o",
        "--output",
        required=True,
        type=str,
        help="Restore target directory",
    )
    general.add_argument(
        "-u",
        "--until",
        required=False,
        type=str,
        help="Restore only until checkpoint, point in time restore.",
    )
    general.add_argument(
        "-s",
        "--sequence",
        required=False,
        type=str,
        default=None,
        help="Restore image based on specified backup files.",
    )
    general.add_argument(
        "-d",
        "--disk",
        required=False,
        type=str,
        default=None,
        help="Process only disk matching target dev name. (default: %(default)s)",
    )
    general.add_argument(
        "-n",
        "--noprogress",
        required=False,
        action="store_true",
        default=False,
        help="Disable progress bar",
    )
    general.add_argument(
        "-f",
        "--socketfile",
        default=f"/var/tmp/virtnbdbackup.{os.getpid()}",
        type=str,
        help="Use specified file for NBD Server socket (default: %(default)s)",
    )
    general.add_argument(
        "-r",
        "--raw",
        default=False,
        action="store_true",
        help="Copy raw images as is during restore. (default: %(default)s)",
    )
    general.add_argument(
        "-c",
        "--adjust-config",
        default=False,
        action="store_true",
        help="Adjust vm configuration during restore. (default: %(default)s)",
    )
    general.add_argument(
        "-D",
        "--define",
        default=False,
        action="store_true",
        help="Register/define VM after restore. (default: %(default)s)",
    )
    general.add_argument(
        "-N",
        "--name",
        default=None,
        type=str,
        help="Define restored domain with specified name",
    )

    remote = parser.add_argument_group("Remote Restore options")
    argopt.addRemoteArgs(remote)
    logGroup = parser.add_argument_group("Logging options")
    argopt.addLogArgs(logGroup, parser.prog)
    debug = parser.add_argument_group("Debug options")
    argopt.addDebugArgs(debug)

    return parser


# Run the command line tool only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
