#!/usr/bin/python3

# SPDX-License-Identifier: GPL-2.0-or-later

"""
This script updates the copyright headers on all files project-wide. It derives
the authorship and copyright years information from the git history of the
project; hence, this script must be run within a git clone of the project's
repository. Options are available to specify the license text file, the files
to edit, and files to exclude. Default options can be specified in the
containing project's setup.cfg under [{SETUP_SECTION}]
"""

from __future__ import annotations

import os
import sys
assert sys.version_info >= (3, 6), 'Script requires Python 3.6+'
import tempfile
import typing as t
from argparse import ArgumentParser, Namespace
from configparser import ConfigParser
from operator import attrgetter
from itertools import groupby
from datetime import datetime
from subprocess import Popen, PIPE, DEVNULL
from pathlib import Path
from fnmatch import fnmatch


# Parent of the directory containing this script; setup.cfg and the default
# license path are looked up relative to it
PROJECT_ROOT = (Path(__file__).parent / '..').resolve()
# Name of the setup.cfg section holding this script's default options
# (the script's own file name followed by ':settings')
SETUP_SECTION = str(Path(__file__).name) + ':settings'
# Default prefix of the line carrying the SPDX license identifier
SPDX_PREFIX = 'SPDX-License-Identifier:'
# Default prefix placed before each copyright attribution
COPYRIGHT_PREFIX = 'Copyright (c)'


def main(args: t.Optional[t.List[str]] = None):
    """
    Entry point: re-write the copyright header of every selected file.

    The configuration is built by :func:`get_config`, the files and their
    contributors come from :func:`get_copyrights`, and each file is re-written
    through :class:`AtomicReplaceFile` so a failure part-way through a file
    leaves the original untouched.

    :param list args:
        The command line arguments to parse (excluding the program name). If
        ``None`` (the default), ``sys.argv[1:]`` is used.
    """
    if args is None:
        args = sys.argv[1:]
    config = get_config(args)

    writer = CopyWriter.from_config(config)
    for path, copyrights in get_copyrights(config.include, config.exclude):
        print(f'Re-writing {path}...')
        # Most recently active contributors first; ties broken by author name
        copyrights = sorted(
            copyrights, reverse=True, key=lambda c: (max(c.years), c.author))
        with AtomicReplaceFile(path, encoding='utf-8') as target:
            # Read with the same encoding the replacement is written in;
            # relying on the locale's default here could transcode the file
            # (or fail to decode it) on systems whose locale is not UTF-8
            with path.open('r', encoding='utf-8') as source:
                for chunk in writer.transform(source, copyrights):
                    target.write(chunk)


def get_config(args: t.List[str]) -> Namespace:
    """
    Build the script's configuration from setup.cfg and the command line.

    Defaults are read from the ``SETUP_SECTION`` section of the project's
    setup.cfg (falling back to the built-in defaults below), and are then used
    as the defaults of the command line parser, so command line options take
    precedence over the configuration file.

    :param list args:
        The command line arguments to parse (excluding the program name).

    :returns:
        The parsed :class:`argparse.Namespace`; its ``include`` and ``exclude``
        attributes are converted to sets.
    """
    config = ConfigParser(
        defaults={
            'include': '**/*',
            'exclude': '',
            'license': 'LICENSE.txt',
            'preamble': '',
            'strip_preamble': 'false',
            'spdx_prefix': SPDX_PREFIX,
            'copy_prefix': COPYRIGHT_PREFIX,
        },
        delimiters=('=',), default_section=SETUP_SECTION,
        empty_lines_in_values=False, interpolation=None,
        converters={'list': lambda s: s.strip().splitlines()})
    config.read(PROJECT_ROOT / 'setup.cfg')
    sect = config[SETUP_SECTION]
    # Resolve license default relative to setup.cfg
    if sect['license']:
        sect['license'] = str(PROJECT_ROOT / sect['license'])
    # The built-in default is declared as "strip_preamble"; the hyphenated
    # spelling is still honoured for configuration files that already use it
    strip_default = sect.getboolean(
        'strip-preamble', fallback=sect.getboolean('strip_preamble'))

    parser = ArgumentParser(description=__doc__.format(**globals()))
    parser.add_argument(
        '-i', '--include', action='append', metavar='GLOB',
        default=sect.getlist('include'),
        help="The set of patterns that a file must match to be included in "
        "the set of files to re-write. Can be specified multiple times to "
        "add several patterns. Default: %(default)r")
    parser.add_argument(
        '-e', '--exclude', action='append', metavar='GLOB',
        default=sect.getlist('exclude'),
        help="The set of patterns that a file must *not* match to be included "
        "in the set of files to re-write. Can be specified multiple times to "
        "add several patterns. Default: %(default)r")
    parser.add_argument(
        '-l', '--license', action='store', type=Path, metavar='PATH',
        default=sect['license'],
        help="The file containing the project's license text. If this file "
        "contains a SPDX-License-Identifier line (in addition to the license "
        "text itself), then matching license text found in source files will "
        "be replaced by the SPDX-License-Identifier line (appropriately "
        "commented). Default: %(default)s")
    parser.add_argument(
        '-p', '--preamble', action='append', metavar='STR',
        default=sect.getlist('preamble'),
        help="The line(s) of text to insert before the copyright attributions "
        "in source files. This is typically a brief description of the "
        "project. Can be specified multiple times to add several lines. "
        "Default: %(default)r")
    parser.add_argument(
        '-S', '--spdx-prefix', action='store', metavar='STR',
        default=sect['spdx_prefix'],
        help="The prefix on the line in the license file, and within comments "
        "of source files that identifies the appropriate license from the "
        "SPDX list. Default: %(default)r")
    parser.add_argument(
        '-C', '--copy-prefix', action='store', metavar='STR',
        default=sect['copy_prefix'],
        help="The prefix before copyright attributions in source files. "
        "Default: %(default)r")
    parser.add_argument(
        '--no-strip-preamble', action='store_false', dest='strip_preamble')
    parser.add_argument(
        '--strip-preamble', action='store_true',
        default=strip_default,
        help="If enabled, any existing preamble matching that specified "
        "by --preamble will be removed. This can be used to change the "
        "preamble text in files by first specifying the old preamble with "
        "this option, then running a second time with the new preamble")
    # Both flags share one destination and argparse takes the default of the
    # first action registered for it (--no-strip-preamble, which would
    # otherwise default to True); set the default for the destination
    # explicitly so the configured value is the one that applies
    parser.set_defaults(strip_preamble=strip_default)

    ns = parser.parse_args(args)
    ns.include = set(ns.include)
    ns.exclude = set(ns.exclude)
    return ns


class Copyright(t.NamedTuple):
    """
    One author's copyright attribution: who they are and the set of years in
    which they contributed. The string form is ``YEARS AUTHOR <EMAIL>`` where
    YEARS is either a single year or a ``first-last`` range.
    """
    author: str
    email: str
    years: t.Set[int]

    def __str__(self):
        first, last = min(self.years), max(self.years)
        # Only show a range when more than one year is recorded
        span = f'{first}-{last}' if len(self.years) > 1 else f'{first}'
        return f'{span} {self.author} <{self.email}>'


def get_copyrights(include: t.Set[str], exclude: t.Set[str])\
        -> t.Iterator[t.Tuple[Path, t.Iterable[Copyright]]]:
    """
    Yield ``(path, copyrights)`` pairs for every selected file, where
    *copyrights* is a list holding one :class:`Copyright` per distinct
    (author, email) pair found by :func:`get_contributions`, carrying the set
    of years in which that person contributed to the file.
    """
    def person(entry):
        return entry.author, entry.email

    # groupby only merges adjacent items, so order by path first and then by
    # person within each path
    ordered = sorted(
        get_contributions(include, exclude),
        key=attrgetter('path', 'author', 'email'))

    per_file = {}
    for path, entries in groupby(ordered, key=attrgetter('path')):
        per_file[path] = list(entries)

    for path, entries in per_file.items():
        copyrights = []
        for (author, email), group in groupby(entries, key=person):
            years = {entry.year for entry in group}
            copyrights.append(Copyright(author, email, years))
        yield path, copyrights


class Contribution(t.NamedTuple):
    """
    A single attributed line of a file, as produced by
    :func:`get_contributions`.
    """
    # Name of the line's author
    author: str
    # Author's e-mail address, without the surrounding angle brackets
    email: str
    # Year taken from the line's author timestamp
    year: int
    # The file the line belongs to
    path: Path


def get_contributions(include: t.Set[str], exclude: t.Set[str])\
        -> t.Iterator[Contribution]:
    """
    Yield one :class:`Contribution` per blamed line of every selected file.

    For each path produced by :func:`get_source_paths`, ``git blame
    --line-porcelain`` is run against ``HEAD`` and its output parsed: the
    ``author``, ``author-mail`` and ``author-time`` fields are collected and a
    contribution is emitted when the ``filename`` field closing each entry's
    header is reached.

    :param set include:
        Patterns a file must match to be considered.

    :param set exclude:
        Patterns a file must not match to be considered.

    :raises RuntimeError:
        If git exits with a non-zero status, or an entry's header lacks one of
        the expected fields.
    """
    for path in get_source_paths(include, exclude):
        command = [
            'git', 'blame', '--line-porcelain', 'HEAD', '--', str(path)]
        # stderr is discarded (as in get_source_paths) rather than sent to a
        # pipe nobody reads, which could block git once the pipe fills up.
        # The context manager closes the pipe and reaps the process even when
        # the consumer abandons this generator early
        with Popen(command, stdout=PIPE, stderr=DEVNULL,
                   universal_newlines=True) as blame:
            author = email = year = None
            if blame.stdout is not None:
                for line in blame.stdout:
                    if line.startswith('author '):
                        author = line.split(' ', 1)[1].rstrip()
                    elif line.startswith('author-mail '):
                        email = line.split(' ', 1)[1].rstrip()
                        email = email.lstrip('<').rstrip('>')
                    elif line.startswith('author-time '):
                        # Forget the timezone; we only want the year anyway
                        timestamp = int(line.split(' ', 1)[1].strip())
                        year = datetime.fromtimestamp(timestamp).year
                    elif line.startswith('filename '):
                        # Explicit check instead of assert, which would be
                        # stripped when running with -O
                        if author is None or email is None or year is None:
                            raise RuntimeError(
                                f'Incomplete git blame entry for {path}')
                        yield Contribution(
                            author=author, email=email, year=year, path=path)
                        author = email = year = None
        if blame.returncode != 0:
            raise RuntimeError(
                f'git blame failed for {path} '
                f'(exit status {blame.returncode})')


def get_source_paths(include: t.Set[str], exclude: t.Set[str])\
        -> t.Iterator[Path]:
    """
    Yield the path of every file tracked at ``HEAD`` that matches at least one
    *include* pattern and no *exclude* pattern.

    The file list comes from ``git ls-tree -r --name-only HEAD`` and patterns
    are tested with :func:`fnmatch.fnmatch`. An empty *include* set is treated
    as ``{'*'}``.

    NOTE(review): fnmatch does not treat ``/`` specially, so ``*`` also matches
    across directories, and a pattern such as ``**/*`` only matches names
    containing a ``/`` -- confirm that skipping top-level files is intended.

    :raises RuntimeError:
        If git exits with a non-zero status.
    """
    if not include:
        include = {'*'}
    # The context manager closes the pipe and reaps the process even when the
    # consumer abandons this generator early
    with Popen(
            ['git', 'ls-tree', '-r', '--name-only', 'HEAD'],
            stdout=PIPE, stderr=DEVNULL, universal_newlines=True) as ls_tree:
        if ls_tree.stdout is not None:
            for filename in ls_tree.stdout:
                filename = filename.strip()
                # Exclusions win over inclusions
                if any(fnmatch(filename, pattern) for pattern in exclude):
                    continue
                if any(fnmatch(filename, pattern) for pattern in include):
                    yield Path(filename)
    # Explicit check instead of assert, which would be stripped under -O
    if ls_tree.returncode != 0:
        raise RuntimeError(
            f'git ls-tree failed (exit status {ls_tree.returncode})')


class License(t.NamedTuple):
    """
    The content of a license file, as returned by :func:`get_license`.
    """
    # The SPDX identifier line found in the file, or None if it has none
    ident: t.Optional[str]
    # The remaining lines of the file (the license text itself)
    text: t.List[str]


def get_license(path: Path, *, spdx_prefix: str = SPDX_PREFIX) -> License:
    """
    Read the license file at *path* and split it into its SPDX identifier line
    (if any) and the license text.

    Lines starting with *spdx_prefix* are taken to be the identifier; all
    other lines form the text, with trailing whitespace removed from each line
    and leading and trailing blank lines dropped. The text is an empty list
    when the file holds nothing but blank lines and/or the identifier.

    :param pathlib.Path path:
        The license file to read (as UTF-8).

    :param str spdx_prefix:
        The prefix identifying the SPDX identifier line.

    :raises RuntimeError:
        If more than one line starts with *spdx_prefix*.
    """
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.read().splitlines()

    idents = [
        line.rstrip() for line in lines
        if line.startswith(spdx_prefix)
    ]
    if len(idents) > 1:
        raise RuntimeError(f'More than one {spdx_prefix} line in {path}!')
    ident = idents[0] if idents else None

    body = [
        line.rstrip() for line in lines
        if not line.startswith(spdx_prefix)
    ]
    # Guard against running off either end when there is no license text at
    # all (e.g. a file containing only the SPDX line)
    while body and not body[0]:
        del body[0]
    while body and not body[-1]:
        del body[-1]
    return License(ident, body)


class CopyWriter:
    """
    Transformer for the copyright header in source files. The :meth:`transform`
    method can be called with a file-like object as the *source* and will
    yield chunks of replacement data to be written to the replacement.

    :param pathlib.Path license:
        The file containing the license text (read by :func:`get_license`).

    :param list preamble:
        The lines to place before the copyright attributions, if any.

    :param str spdx_prefix:
        The prefix identifying the SPDX identifier line.

    :param str copy_prefix:
        The prefix placed before each copyright attribution.
    """

    # The script's kinda dumb at this point - only handles straight-forward
    # line-based comments, not multi-line delimited styles like /*..*/
    COMMENTS = {
        '':     '#',
        '.c':   '//',
        '.cpp': '//',
        '.js':  '//',
        '.py':  '#',
        '.rst': '..',
        '.sh':  '#',
        '.sql': '--',
    }

    def __init__(self, license: Path=Path('LICENSE.txt'),
                 preamble: t.Optional[t.List[str]]=None,
                 spdx_prefix: str=SPDX_PREFIX,
                 copy_prefix: str=COPYRIGHT_PREFIX):
        if preamble is None:
            preamble = []
        self.license = get_license(license, spdx_prefix=spdx_prefix)
        self.preamble = preamble
        self.spdx_prefix = spdx_prefix
        self.copy_prefix = copy_prefix

    @classmethod
    def from_config(cls, config: Namespace) -> CopyWriter:
        """
        Construct an instance from the namespace returned by
        :func:`get_config`.
        """
        return cls(
            config.license, config.preamble,
            config.spdx_prefix, config.copy_prefix)

    def transform(self, source: t.TextIO,
                  copyrights: t.List[Copyright], *,
                  comment_prefix: t.Optional[str]=None) -> t.Iterator[str]:
        """
        Yield the lines of *source* with its header replaced by a freshly
        generated one.

        Leading lines are treated as the existing header while they are a
        shebang (first line only), an encoding declaration (first two lines
        only), a blank comment, or a comment starting with the SPDX prefix,
        the copyright prefix, a preamble line, or the first line of the
        license text. The shebang and encoding lines are kept; the others are
        dropped. The new header is emitted at the first line that is none of
        these, followed by the rest of the file unchanged, with a blank line
        guaranteed between the two.

        A source consisting of nothing but header lines still gets a new
        header, and if the last line of a multi-line license is never found,
        the lines consumed while searching for it are emitted after the new
        header rather than discarded. An empty source yields nothing.

        :param source:
            The file-like object to read. Its ``name`` is used to select the
            comment prefix when *comment_prefix* is not given.

        :param list copyrights:
            The attributions to place in the header, in output order.

        :param str comment_prefix:
            The line-comment marker to use. If ``None`` (the default), it is
            looked up in :attr:`COMMENTS` by the suffix of ``source.name``.

        :raises KeyError:
            If *comment_prefix* is not given and the file's suffix has no
            entry in :attr:`COMMENTS`.
        """
        if comment_prefix is None:
            comment_prefix = self.COMMENTS[Path(source.name).suffix]

        # Line prefixes identifying the parts of an existing header
        spdx_marker = f'{comment_prefix} {self.spdx_prefix}'
        copy_marker = f'{comment_prefix} {self.copy_prefix}'
        preamble_markers = tuple(
            f'{comment_prefix} {pre_line}' for pre_line in self.preamble)
        license_text = self.license.text
        if license_text:
            license_start = f'{comment_prefix} {license_text[0]}'
            license_end = f'{comment_prefix} {license_text[-1]}'
        else:
            # No license text to look for (e.g. an SPDX-only license file)
            license_start = license_end = None

        state = 'header'
        empty = True
        linenum = 0
        # Lines consumed while looking for the end of the license text
        held = []
        for linenum, line in enumerate(source, start=1):
            if state == 'header':
                if linenum == 1 and line.startswith('#!'):
                    yield line
                    empty = False
                elif linenum < 3 and (
                        'fileencoding=' in line or '-*- coding:' in line):
                    yield line
                    empty = False
                elif line.rstrip() == comment_prefix:
                    pass # skip blank comment lines
                elif line.startswith(spdx_marker):
                    pass # skip existing SPDX ident
                elif line.startswith(copy_marker):
                    pass # skip existing copyright lines
                elif line.startswith(preamble_markers):
                    pass # skip existing preamble
                elif (license_start is not None and
                        line.startswith(license_start)):
                    # A one-line license is complete on this very line; only
                    # a longer one needs its last line hunting down
                    if len(license_text) > 1:
                        state = 'license' # skip existing license lines
                        held = [line]
                else:
                    yield from self._generate_header(
                        copyrights, comment_prefix, empty)
                    state = 'blank'
            elif state == 'license':
                if line.startswith(license_end):
                    held = []
                    yield from self._generate_header(
                        copyrights, comment_prefix, empty)
                    state = 'blank'
                    continue
                held.append(line)
            # Deliberately "if" rather than "elif": the line that ended the
            # header above is the first line of the body and is handled here
            if state == 'blank':
                # Ensure there's a blank line between license and start of the
                # source body
                if line.strip():
                    yield '\n'
                yield line
                state = 'body'
            elif state == 'body':
                yield line

        if state == 'license':
            # The end of the license never turned up, so the held lines were
            # not a copy of the license after all; keep them as the body
            yield from self._generate_header(
                copyrights, comment_prefix, empty)
            yield '\n'
            yield from held
        elif state == 'header' and linenum > 0:
            # The source was nothing but header; replace it rather than
            # silently dropping it
            yield from self._generate_header(
                copyrights, comment_prefix, empty)

    def _generate_header(self, copyrights: t.Iterable[Copyright],
                         comment_prefix: str, empty: bool) -> t.Iterator[str]:
        """
        Yield the lines of a new header: the preamble (if any), one line per
        entry of *copyrights*, then either the SPDX identifier line or, when
        the license has none, the full license text. Every line is commented
        with *comment_prefix*. When *empty* is false (something was already
        output before the header), a blank comment line is emitted first.
        """
        if not empty:
            yield comment_prefix + '\n'
        for line in self.preamble:
            yield f'{comment_prefix} {line}\n'
        if self.preamble:
            yield comment_prefix + '\n'
        for attribution in copyrights:
            yield f'{comment_prefix} {self.copy_prefix} {attribution!s}\n'
        yield comment_prefix + '\n'
        if self.license.ident:
            yield f'{comment_prefix} {self.license.ident}\n'
        else:
            for line in self.license.text:
                if line:
                    yield f'{comment_prefix} {line}\n'
                else:
                    yield comment_prefix + '\n'


class AtomicReplaceFile:
    """
    A context manager for atomically replacing a target file.

    Uses :class:`tempfile.NamedTemporaryFile` to construct a temporary file in
    the same directory as the target file. The associated file-like object is
    returned as the context manager's variable; you should write the content
    you wish to this object.

    When the context manager exits, if no exception has occurred, the temporary
    file will be renamed over the target file atomically (after copying
    permissions from the target file). If an exception occurs during the
    context manager's block, or while copying the permissions or renaming, the
    temporary file will be deleted leaving the original target file unaffected
    and the exception will be raised.

    :param pathlib.Path path:
        The path and filename of the target file, which must already exist.
        The temporary file is created in the same directory.

    :param str encoding:
        If ``None`` (the default), the temporary file will be opened in binary
        mode. Otherwise, this specifies the encoding to use with text mode.
    """
    def __init__(self, path: t.Union[str, Path],
                 encoding: t.Optional[str] = None):
        if isinstance(path, str):
            path = Path(path)
        self._path = path
        # delete=False: the file must outlive its close so it can be renamed
        self._tempfile = tempfile.NamedTemporaryFile(
            mode='wb' if encoding is None else 'w',
            dir=str(self._path.parent), encoding=encoding, delete=False)
        self._withfile = None

    def __enter__(self):
        self._withfile = self._tempfile.__enter__()
        return self._withfile

    def __exit__(self, exc_type, exc_value, exc_tb):
        temp_name = self._withfile.name
        replaced = False
        try:
            if exc_type is None:
                # Only worth copying permissions when the file will be kept;
                # on the failure path a stat() error here would otherwise
                # mask the exception raised in the block
                os.fchmod(
                    self._withfile.file.fileno(), self._path.stat().st_mode)
            result = self._tempfile.__exit__(exc_type, exc_value, exc_tb)
            if exc_type is None:
                # os.replace overwrites an existing target on all platforms
                os.replace(temp_name, str(self._path))
                replaced = True
        finally:
            if not replaced:
                # Covers both a failed block and a failure above: never leave
                # the temporary file behind
                self._tempfile.close()
                try:
                    os.unlink(temp_name)
                except FileNotFoundError:
                    pass
        return result


# Only run when executed as a script (not when imported); main()'s return
# value is passed to sys.exit() as the process's exit status
if __name__ == '__main__':
    sys.exit(main())
