#!/usr/bin/env python
#
# Copyright (c) 2014 Luis R. Rodriguez  <mcgrof@suse.com>
# Copyright (c) 2013 Johannes Berg <johannes@sipsolutions.net>
#
# This file is released under the GPLv2.
#
# Python wrapper for Coccinelle for multithreaded support,
# designed to be used for working on a git tree, and with sensible
# defaults, specifically for kernel developers.

from multiprocessing import Process, cpu_count, Queue
import argparse, subprocess, os, sys, re
import tempfile, shutil
import uuid

class ReqError(Exception):
    """Base class for the errors raised while checking program requirements."""
class ExecutionError(ReqError):
    """Raised when an external command exits with a non-zero status.

    The exit status is available as ``error_code`` and is also passed to the
    base exception so that ``str(exc)`` and ``exc.args`` are not empty when
    the error reaches a traceback or a log.
    """
    def __init__(self, errcode):
        ReqError.__init__(self, errcode)
        self.error_code = errcode

class Req:
    """Verify binary (external program) dependencies from Python code.

    Every failed ``require*`` check is recorded on the instance; call
    reqs_match() afterwards to learn whether all checks passed.  When
    ``chatty`` is true the reason for each failure is written to stdout,
    otherwise the checks are silent.
    """

    # Release candidate numbers are shifted down by MAX_RC + 1 when weighting,
    # so "-rcN" weighs less than the final release as long as N <= MAX_RC.
    MAX_RC = 25

    # Linux style release strings, e.g. "v3.14.2", "1.0.0-rc21".  Every
    # fragment is a raw string so the regex escapes reach the re module intact.
    _REL_RC_RE = re.compile(r"v*(?P<VERSION>\d+)\.+"
                            r"(?P<PATCHLEVEL>\d+)[.]*"
                            r"(?P<SUBLEVEL>\d*)"
                            r"(?P<EXTRAVERSION>[-rc]+\w*)\-*"
                            r"(?P<RELMOD_UPDATE>\d*)[-]*")
    _REL_RE = re.compile(r"v*(?P<VERSION>\d+)\.+"
                         r"(?P<PATCHLEVEL>\d+)[.]*"
                         r"(?P<SUBLEVEL>\d*)[.]*"
                         r"(?P<EXTRAVERSION>\w*)\-*"
                         r"(?P<RELMOD_UPDATE>\d*)[-]*")

    def __init__(self, chatty=True):
        self.all_reqs_ok = True
        self.debug = False
        self.chatty = chatty

    def logwrite(self, msg):
        """Write msg to stdout and flush, but only when chatty."""
        if self.chatty:
            sys.stdout.write(msg)
            sys.stdout.flush()

    def enable_debug(self):
        """Turn on the extra diagnostics printed by the version checks."""
        self.debug = True

    def reqs_match(self):
        """Return True if no requirement check has failed so far."""
        if self.all_reqs_ok:
            return True
        self.logwrite("You have unfulfilled binary requirements\n")
        return False

    def req_missing(self, program):
        """Record (and report) that program is not installed."""
        self.all_reqs_ok = False
        self.logwrite("You need to have installed: %s\n" % program)

    def req_old_program(self, program, version_req):
        """Record (and report) that program is older than version_req."""
        self.all_reqs_ok = False
        self.logwrite("You need to have installed: %s >= %s\n" % (program, version_req))

    def _run(self, cmd):
        """Run cmd and return (returncode, output); stderr is merged into output."""
        process = subprocess.Popen(cmd,
                                   stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                   close_fds=True, universal_newlines=True)
        # communicate() waits for the process, so returncode is set afterwards.
        output = process.communicate()[0]
        return process.returncode, output

    def which(self, program):
        """Return the raw output of ``which program``.

        Raises ExecutionError carrying the exit status if ``which`` fails.
        """
        returncode, output = self._run(['which', program])
        if returncode != 0:
            raise ExecutionError(returncode)
        return output

    def req_exists(self, program):
        """Return True if ``which program`` succeeds."""
        return self._run(['which', program])[0] == 0

    def req_get_prog_version(self, program, version_query, version_pos):
        """
        Suppose you have a binary that outputs:
        $ spatch --version
        spatch version 1.0.0-rc21 with Python support and with PCRE support

        Every program varies in what it wants you to query it with for a
        version string.  Pass what the program expects for its version query
        in version_query, and in version_pos the index of the version string
        within the whitespace-split output (2 in the example above).

        Raises ExecutionError if the program exits with a non-zero status.
        """
        cmd = [program, version_query]
        returncode, output = self._run(cmd)
        if returncode != 0:
            raise ExecutionError(returncode)
        tokens = output.split()
        if self.debug:
            indexed = [[i, x] for i, x in enumerate(tokens)]
            sys.stdout.write("Running '%s' got us this break down:\n%s\n" %
                             (' '.join(cmd), "\n".join(map(str, indexed))))
            sys.stdout.write("You are using for version: %s\n" % tokens[version_pos])
            sys.stdout.write("Specifically your idx, element: %s\n" % (indexed[version_pos],))
        return tokens[version_pos]

    def __compute_rel_weight(self, rel_specs):
        """Fold the parsed release fields of rel_specs into one comparable integer."""
        if self.debug:
            sys.stdout.write("VERSION       = %s\n" % rel_specs['VERSION'])
            sys.stdout.write("PATCHLEVEL    = %s\n" % rel_specs['PATCHLEVEL'])
            sys.stdout.write("SUBLEVEL      = %s\n" % rel_specs['SUBLEVEL'])
            sys.stdout.write("EXTRAVERSION  = %s\n" % rel_specs['EXTRAVERSION'])
            sys.stdout.write("RELMOD_UPDATE = %s\n" % rel_specs['RELMOD_UPDATE'])

        extra = 0
        extraversion = rel_specs['EXTRAVERSION']
        if extraversion != '':
            if "." in extraversion or "rc" in extraversion:
                # lstrip() removes any leading '-', 'r' and 'c' characters,
                # leaving just the release candidate number (if any).
                rc = extraversion.lstrip("-rc")
                if rc != "":
                    # Negative for rc <= MAX_RC: an -rc sorts below the release.
                    extra = int(rc) - (Req.MAX_RC + 1)
            else:
                extra = int(extraversion) + 10

        if rel_specs['SUBLEVEL'] != '':
            sublevel = int(rel_specs['SUBLEVEL'].lstrip(".")) * 20
        else:
            sublevel = 5

        relmod = 0
        if rel_specs['RELMOD_UPDATE'] != '':
            relmod = int(rel_specs['RELMOD_UPDATE'])

        return ((int(rel_specs['VERSION']) << 32) +
                (int(rel_specs['PATCHLEVEL']) << 16) +
                (sublevel << 8) +
                (extra * 60) + (relmod * 2))

    def req_get_rel_spec(self, rel):
        """Parse the release string rel.

        Returns a dict with the keys VERSION, PATCHLEVEL, SUBLEVEL,
        EXTRAVERSION and RELMOD_UPDATE (all strings, possibly empty), or
        None when rel does not look like a release string.
        """
        pattern = self._REL_RC_RE if "rc" in rel else self._REL_RE
        m = pattern.match(rel)
        if not m:
            return None
        return m.groupdict()

    def compute_rel_weight(self, rel):
        """Return the integer weight of release string rel, 0 if unparsable."""
        rel_specs = self.req_get_rel_spec(rel)
        if not rel_specs:
            return 0
        return self.__compute_rel_weight(rel_specs)

    def linux_version_cmp(self, version_req, version):
        """
        If the program follows the linux version style scheme you can
        use this to compare versions.

        Returns -1 when version is older than version_req, 0 otherwise.
        """
        weight_has = self.compute_rel_weight(version)
        weight_req = self.compute_rel_weight(version_req)

        if self.debug:
            sys.stdout.write("You have program weight: %s\n" % weight_has)
            sys.stdout.write("Required program weight: %s\n" % weight_req)

        if weight_has < weight_req:
            return -1
        return 0

    def require_version(self, program, version_query, version_req, version_pos, version_cmp):
        """
        If you have a program version requirement you can specify it here,
        as for the other flags refer to req_get_prog_version().

        version_cmp(version_req, version) must return 0 when the installed
        version is acceptable.  Returns True when the program exists and is
        recent enough, otherwise records the failure and returns False.
        """
        if not self.require(program):
            return False
        version = self.req_get_prog_version(program, version_query, version_pos)
        if self.debug:
            sys.stdout.write("Checking release specs and weight: for: %s\n" % program)
            sys.stdout.write("You have version: %s\n" % version)
            sys.stdout.write("Required version: %s\n" % version_req)
        if version_cmp(version_req, version) != 0:
            self.req_old_program(program, version_req)
            return False
        return True

    def require(self, program):
        """Return True if program is installed, else record it as missing."""
        if self.req_exists(program):
            return True
        self.req_missing(program)
        return False

    def require_hint(self, program, package_hint):
        """Like require(), but also suggest package_hint when program is missing."""
        if self.require(program):
            return True
        # Go through logwrite() so a non-chatty instance stays silent, the
        # same as for the "You need to have installed" message that precedes.
        self.logwrite("Try installing the package: %s\n" % package_hint)
        return False

    def coccinelle(self, version):
        """Require spatch >= version, printing install hints on failure."""
        if self.require_version('spatch', '--version', version, 2, self.linux_version_cmp):
            return True
        self.logwrite("Try installing the package: coccinelle\n")
        self.logwrite("If that is too old go grab the code from source:\n\n")
        self.logwrite("git clone https://github.com/coccinelle/coccinelle.git\n\n")
        self.logwrite("To build you will need: ocaml ncurses-devel\n\n")
        self.logwrite("If on SUSE / OpenSUSE you will also need: ocaml-ocamldoc\n\n")
        return False

    def make(self, version):
        """Require make >= version."""
        return self.require_version('make', '--version', version, 2, self.linux_version_cmp)

    def gcc(self, version):
        """Require gcc >= version."""
        return self.require_version('gcc', '--version', version, 3, self.linux_version_cmp)

class GitError(Exception):
    """Base class for git related errors."""
class ExecutionGitError(GitError):
    """Git error carrying an exit status in ``error_code``.

    NOTE(review): presumably meant for failed git commands, but the git
    helpers visible in this file raise ExecutionError instead - confirm
    before relying on this type.
    """
    def __init__(self, errcode):
        # Hand the code to the base class so str(exc)/exc.args are not empty.
        GitError.__init__(self, errcode)
        self.error_code = errcode

def _check(process):
    """Raise ExecutionError unless the finished process exited with status 0."""
    status = process.returncode
    if status == 0:
        return
    raise ExecutionError(status)

def git_init(tree=None):
    """Run ``git init`` with tree as the working directory.

    The command output is captured and discarded; a non-zero exit status
    is turned into an exception by _check().
    """
    git = subprocess.Popen(['git', 'init'],
                           cwd=tree,
                           stdout=subprocess.PIPE,
                           stderr=subprocess.STDOUT,
                           close_fds=True,
                           universal_newlines=True)
    # communicate() drains the pipe and waits for git to exit.
    git.communicate()
    _check(git)

def git_rev_parse(tree=None, extra_args=None):
    """Run ``git rev-parse`` with extra_args, using tree as working directory.

    extra_args may be omitted (or None) to pass no additional arguments.

    Returns the first line of the command output (stderr is merged into
    it), or None if git exits with a non-zero status.
    """
    # The default of None must not reach list concatenation.
    cmd = ['git', 'rev-parse'] + list(extra_args or [])
    process = subprocess.Popen(cmd,
                               stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                               close_fds=True, universal_newlines=True, cwd=tree)
    stdout = process.communicate()[0]
    if process.returncode != 0:
        return None
    return stdout.split('\n', 1)[0]

def gitname(path=None):
    """Return the top-level directory of the git tree containing path.

    path may be a directory or a file; when omitted the current directory
    is used.  git is run from the directory itself, or from the directory
    holding the file.

    Returns the first line printed by ``git rev-parse --show-toplevel``,
    or None if git exits with a non-zero status (e.g. path is not inside
    a git tree).
    """
    if path is None:
        path = '.'
    work_dir = path
    if not os.path.isdir(path):
        # dirname() of a bare file name is '', which is not a valid cwd;
        # None makes Popen run in the current directory instead.
        work_dir = os.path.dirname(path) or None
    process = subprocess.Popen(['git', 'rev-parse', '--show-toplevel', path],
                               stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                               close_fds=True, universal_newlines=True, cwd=work_dir)
    stdout = process.communicate()[0]
    if process.returncode != 0:
        return None
    return stdout.split('\n', 1)[0]

def git_add(path, tree=None):
    """Run ``git add --ignore-removal path`` with tree as working directory.

    The command output is captured and discarded; a non-zero exit status
    is turned into an exception by _check().
    """
    add_cmd = ['git', 'add', '--ignore-removal', path]
    git = subprocess.Popen(add_cmd,
                           cwd=tree,
                           stdout=subprocess.PIPE,
                           stderr=subprocess.STDOUT,
                           close_fds=True,
                           universal_newlines=True)
    # communicate() drains the pipe and waits for git to exit.
    git.communicate()
    _check(git)

def git_checkout(tree=None, extra_args=None):
    """Run ``git checkout`` with extra_args, using tree as working directory.

    extra_args may be omitted (or None) to pass no additional arguments.
    The command output is captured and discarded; a non-zero exit status
    is turned into an exception by _check().
    """
    # The default of None must not reach list concatenation.
    cmd = ['git', 'checkout'] + list(extra_args or [])
    process = subprocess.Popen(cmd,
                               stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                               close_fds=True, universal_newlines=True, cwd=tree)
    process.communicate()
    _check(process)

def git_commit_all(message, tree=None):
    """Stage everything under tree and commit it with message.

    Runs git_add('.') followed by ``git commit --allow-empty -a``, so a
    commit is created even when nothing changed.  A non-zero exit status
    from git is turned into an exception by _check().
    """
    git_add('.', tree=tree)
    commit_cmd = ['git', 'commit', '--allow-empty', '-a', '-m', message]
    git = subprocess.Popen(commit_cmd,
                           cwd=tree,
                           stdout=subprocess.PIPE,
                           stderr=subprocess.STDOUT,
                           close_fds=True,
                           universal_newlines=True)
    # communicate() drains the pipe and waits for git to exit.
    git.communicate()
    _check(git)

def git_diff(tree=None, extra_args=None):
    """Run ``git diff --color=always`` with extra_args and return its output.

    extra_args may be omitted (or None) to pass no additional arguments.
    git runs with tree as working directory and stderr is merged into the
    returned text.  A non-zero exit status is turned into an exception by
    _check().
    """
    # The default of None must not reach list concatenation.
    cmd = ['git', 'diff', '--color=always'] + list(extra_args or [])

    process = subprocess.Popen(cmd,
                               stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                               close_fds=True, universal_newlines=True, cwd=tree)
    stdout = process.communicate()[0]
    _check(process)

    return stdout

# simple tempdir wrapper object for 'with' statement
#
# Usage:
# with tempdir() as tmpdir:
#     os.chdir(tmpdir)
#     do something
#
class tempdir(object):
    """Context manager creating a temporary directory for a 'with' block.

    suffix, prefix and dir are handed to tempfile.mkdtemp().  The directory
    path is the value bound by ``as``.  On exit the directory and its
    contents are removed, unless nodelete is true, in which case the
    directory is kept and its path is printed.
    """
    def __init__(self, suffix='', prefix='', dir=None, nodelete=False):
        # Keep the caller's values; they are used when the block is entered.
        self.suffix = suffix
        self.prefix = prefix
        self.dir = dir
        self.nodelete = nodelete

    def __enter__(self):
        self._name = tempfile.mkdtemp(suffix=self.suffix,
                                      prefix=self.prefix,
                                      dir=self.dir)
        return self._name

    def __exit__(self, type, value, traceback):
        if self.nodelete:
            print('not deleting directory %s!' % self._name)
        else:
            shutil.rmtree(self._name)
def apply_patches(args, patch_src, target_dir, logwrite=lambda x:None):
    """
    Apply every ``*.patch`` file found below patch_src to target_dir.

    Patches are applied in sorted path order with ``patch -p1`` run from
    target_dir.  After each patch any ``.orig``/``.rej`` left-overs are
    removed and the result is committed with git_commit_all().

    args is accepted for interface compatibility and is not consulted.
    logwrite is called with progress and failure messages.

    Raises Exception('Patch failed') when ``patch`` exits with a non-zero
    status; its output is passed to logwrite first.
    """
    patches = sorted(os.path.join(root, f)
                     for root, dirs, files in os.walk(patch_src)
                     for f in files
                     if f.endswith('.patch'))
    for pfile in patches:
        # Name relative to the patch directory, whether or not patch_src
        # was given with a trailing separator.
        print_name = os.path.relpath(pfile, patch_src)

        logwrite("Applying patch %s\n" % pfile)

        with open(pfile, 'r') as patch_file:
            patch_text = patch_file.read()

        process = subprocess.Popen(['patch', '-p1'], stdout=subprocess.PIPE,
                                   stderr=subprocess.STDOUT, stdin=subprocess.PIPE,
                                   close_fds=True, universal_newlines=True,
                                   cwd=target_dir)
        output = process.communicate(input=patch_text)[0]
        if process.returncode != 0:
            # The output of patch is not shown anywhere else, so report it
            # on every failure; logwrite does not add newlines itself.
            logwrite("Failed to apply changes from %s\n" % print_name)
            for line in output.splitlines():
                logwrite('> %s\n' % line)
            raise Exception('Patch failed')

        # remove orig/rej files that patch sometimes creates
        for root, dirs, files in os.walk(target_dir):
            for f in files:
                if f.endswith(('.orig', '.rej')):
                    os.unlink(os.path.join(root, f))
        git_commit_all(tree=target_dir, message="apply patch %s" % (print_name))
class CoccinelleError(Exception):
    """Base class for the errors raised when running Coccinelle fails."""

class ExecutionErrorThread(CoccinelleError):
    """Raised when one spatch worker exits with a non-zero status.

    Constructing the exception reports the failure through logwrite: first
    the contents of fn (the log of the worker that failed), then the
    contents of every existing '.tmp_spatch_worker.<n>' file in directory
    t for n in range(threads).  Those worker files are deleted once they
    have been logged.

    errcode is stored as ``error_code``.  cocci_file is accepted but not
    used here.
    """
    def __init__(self, errcode, fn, cocci_file, threads, t, logwrite, print_name):
        # Hand the code to the base class so str(exc)/exc.args are not empty.
        CoccinelleError.__init__(self, errcode)
        self.error_code = errcode
        logwrite("Failed to apply changes from %s\n" % print_name)

        logwrite("Specific log output from change that failed using %s\n" % print_name)
        with open(fn, 'r') as tf:
            # One write for the whole log rather than one per character.
            logwrite(tf.read())

        logwrite("Full log using %s\n" % print_name)
        for num in range(threads):
            worker_log = os.path.join(t, '.tmp_spatch_worker.' + str(num))
            if not os.path.isfile(worker_log):
                continue
            with open(worker_log, 'r') as tf:
                logwrite(tf.read())
            os.unlink(worker_log)

class ExecutionErrorCocci(CoccinelleError):
    """Raised when spatch exits with a non-zero status.

    Constructing the exception reports the failure and the captured spatch
    output through logwrite.  errcode is stored as ``error_code``.
    cocci_file is accepted but not used here.
    """
    def __init__(self, errcode, output, cocci_file, logwrite, print_name):
        # Hand the code to the base class so str(exc)/exc.args are not empty.
        CoccinelleError.__init__(self, errcode)
        self.error_code = errcode
        logwrite("Failed to apply changes from %s\n" % print_name)
        logwrite(output)

def spatch(cocci_file, outdir, logwrite, num_jobs, print_name, extra_args=None):
    """Run spatch in place on outdir using its own ``--jobs`` parallelism.

    Exits the program with status 1 unless spatch >= 1.0.2 is installed.

    num_jobs is the number of jobs to use; a false value selects the
    number of CPUs.  ``--jobs`` is only passed when more than one job is
    used.  extra_args, if given, is appended to the spatch command line.
    The command line is reported through logwrite before it is run.

    Returns the combined stdout/stderr of spatch.  Raises
    ExecutionErrorCocci when spatch exits with a non-zero status.
    """
    req = Req(chatty=True)
    req.coccinelle('1.0.2')

    if not req.reqs_match():
        sys.exit(1)

    # Only ask for the CPU count when no explicit job count was given.
    threads = int(num_jobs) if num_jobs else cpu_count()

    cmd = ['spatch',
           '--sp-file', cocci_file,
           '--in-place',
           '--recursive-includes',
           '--relax-include-path',
           '--timeout', '120',
           '--dir', outdir]

    if threads > 1:
        cmd.extend(['--jobs', str(threads)])

    if extra_args:
        cmd.extend(extra_args)

    logwrite("%s\n" % " ".join(cmd))

    sprocess = subprocess.Popen(cmd,
                                stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                close_fds=True, universal_newlines=True)
    # communicate() waits for spatch to exit.
    output = sprocess.communicate()[0]
    if sprocess.returncode != 0:
        raise ExecutionErrorCocci(sprocess.returncode, output, cocci_file, logwrite, print_name)
    return output

def spatch_old(cocci_file, outdir,
               max_threads, thread_id, temp_dir, ret_q, extra_args=None):
    """Worker running one slice of an spatch run (``-max``/``-index``).

    The spatch output goes to '.tmp_spatch_worker.<thread_id>' inside
    temp_dir.  A ``(returncode, log_file_name)`` tuple is always put on
    ret_q, also when spatch could not be started, so that a consumer
    waiting on the queue is never left blocked.

    extra_args, if given, is appended to the spatch command line.
    """
    cmd = ['spatch',
           '--sp-file', cocci_file,
           '--in-place',
           '--recursive-includes',
           '--relax-include-path',
           '--timeout', '120',
           '--dir', outdir]

    if max_threads > 1:
        cmd.extend(['-max', str(max_threads), '-index', str(thread_id)])

    if extra_args:
        cmd.extend(extra_args)

    fn = os.path.join(temp_dir, '.tmp_spatch_worker.' + str(thread_id))
    logwrite("%s\n" % " ".join(cmd))

    # Reported as a failure unless spatch actually ran and said otherwise.
    returncode = 1
    try:
        with open(fn, 'w') as outfile:
            try:
                sprocess = subprocess.Popen(cmd,
                                            stdout=outfile, stderr=subprocess.STDOUT,
                                            close_fds=True, universal_newlines=True)
            except OSError as err:
                # e.g. spatch is not installed: leave the reason in the log.
                outfile.write("Could not run %s: %s\n" % (cmd[0], err))
            else:
                returncode = sprocess.wait()
    finally:
        ret_q.put((returncode, fn))

def threaded_spatch(cocci_file, outdir, logwrite, num_jobs,
                    print_name, extra_args=None):
    """Run spatch split over several worker processes.

    One spatch_old() worker is started per job; num_jobs is the number of
    jobs, a false value selects the number of CPUs.  Each worker logs to
    its own file in a temporary directory.

    Returns the concatenated worker logs, in worker order.  Raises
    ExecutionErrorThread for the first worker that reported a non-zero
    status; this happens only after every worker has finished, so no
    worker is still writing when the logs are read and the temporary
    directory is removed.
    """
    extra_args = list(extra_args) if extra_args else []
    threads = int(num_jobs) if num_jobs else cpu_count()
    ret_q = Queue()
    with tempdir() as t:
        jobs = [Process(target=spatch_old, args=(cocci_file, outdir,
                                                 threads, num, t, ret_q,
                                                 extra_args))
                for num in range(threads)]
        for job in jobs:
            job.start()

        # Drain one result per worker, remembering the first failure.
        failure = None
        for _ in jobs:
            ret, fn = ret_q.get()
            if ret != 0 and failure is None:
                failure = (ret, fn)

        for job in jobs:
            job.join()

        if failure is not None:
            raise ExecutionErrorThread(failure[0], failure[1], cocci_file,
                                       threads, t, logwrite, print_name)

        chunks = []
        for num in range(threads):
            fn = os.path.join(t, '.tmp_spatch_worker.' + str(num))
            with open(fn, 'r') as tf:
                chunks.append(tf.read())
            os.unlink(fn)
        return "".join(chunks)

def logwrite(msg):
    """Write msg to standard output as-is and flush it right away."""
    out = sys.stdout
    out.write(msg)
    out.flush()

def _main():
    """Command line entry point.

    Parses the arguments, picks the spatch options that fit the target
    (glimpse index, git grep or coccigrep), runs the SmPL patch on the
    target and, with --show-proof, compares the result against the patch
    series stored in the directory named after the cocci file.

    Returns 0 on success and -2 when the arguments or the environment are
    not usable.
    """
    parser = argparse.ArgumentParser(description='Wrapper around Coccinelle spatch ' +
                                     'which infers which options to enable.')
    parser.add_argument('cocci_file', metavar='<SmPL patch>', type=str,
                        help='This is the Coccinelle file you want to use')
    parser.add_argument('target_dir', metavar='<target>', type=str,
                        help='Target directory or file to modify')
    parser.add_argument('-p', '--profile-cocci', const=True, default=False, action="store_const",
                        help='Enable profile, this will pass --profile  to Coccinelle.')
    parser.add_argument('-s', '--show-proof', const=True, default=False, action="store_const",
                        help='Show proof that the provided SmPL patch can replace a respective patch series')
    parser.add_argument('-j', '--jobs', metavar='<jobs>', type=str, default=None,
                        help='Number of jobs to run Coccinelle with, ' +
                        'defaults to the number of CPUs.')
    parser.add_argument('-v', '--verbose', const=True, default=False, action="store_const",
                        help='Enable output from Coccinelle')
    args = parser.parse_args()

    if not os.path.isfile(args.cocci_file):
        logwrite("Path (%s) is not a file\n" % (args.cocci_file))
        return -2
    if not os.path.isfile(args.target_dir) and not os.path.isdir(args.target_dir):
        logwrite("Path (%s) is not a file or directory\n" % (args.target_dir))
        return -2

    # 0 means "not specified": the spatch helpers then use the CPU count.
    jobs = 0
    if args.jobs is not None:
        try:
            jobs = int(args.jobs)
        except ValueError:
            jobs = -1
        if jobs < 0:
            logwrite("Number of jobs (%s) must be a non-negative integer\n" % (args.jobs))
            return -2

    current_branch = None
    smpl_branch_name = "pycocci-smpl-" + str(uuid.uuid4())[:8]
    patch_branch_name = "pycocci-patch-" + str(uuid.uuid4())[:8]

    extra_spatch_args = []
    if args.profile_cocci:
        extra_spatch_args.append('--profile')

    has_spatch_1_0_2 = Req(chatty=False)
    has_spatch_1_0_2.coccinelle('1.0.2')
    modern_spatch = has_spatch_1_0_2.reqs_match()

    git_reqs = Req(chatty=False)
    git_reqs.require('git')

    glimpse_index = os.path.abspath(os.path.join(args.target_dir, '.glimpse_index'))
    git_dir = None

    if git_reqs.reqs_match():
        git_dir = gitname(args.target_dir)

    if args.show_proof:
        # As an example if you use --show-proof patches/collateral-evolutions/network/09-threaded-irq.cocci
        # the patches under 09-threaded-irq will be used for the proof.
        patch_src = args.cocci_file.split('/')[-1].split('.cocci')[0]
        dirname = os.path.dirname(args.cocci_file)
        patch_src = os.path.abspath(os.path.join(dirname, patch_src))
        if not os.path.isdir(patch_src):
            logwrite("Path given (%s) must be a directory with patches\n" % (patch_src))
            return -2
        # The proof is built from git branches, so git is mandatory here;
        # check again, this time telling the user what is missing.
        git_reqs = Req(chatty=True)
        git_reqs.require('git')
        if not git_reqs.reqs_match():
            return -2
        if not git_dir:
            if os.path.isfile(args.target_dir):
                logwrite("Path given (%s) is a file, try passing the directory "
                         "(%s) if you are certain you want us to create a git repo to provide "
                         "a proof there\n" % (args.target_dir, os.path.dirname(args.target_dir)))
                return -2
            logwrite("Path (%s) not part of a git tree, creating one for you...\n" % (args.target_dir))
            git_init(tree=args.target_dir)
            git_commit_all(tree=args.target_dir, message="Initial commit")
        cmd = [ '--abbrev-ref', 'HEAD' ]
        current_branch = git_rev_parse(tree=args.target_dir, extra_args = cmd)
        logwrite("\n")
        logwrite("Current branch: %s\n" % (current_branch))
        logwrite("Patch   branch: %s\n" % (patch_branch_name))
        logwrite("SmPL    branch: %s\n" % (smpl_branch_name))
        logwrite("\n")
        # Both branches start from the current state: the patch series is
        # applied on one, the SmPL patch on the other.
        git_checkout(tree=args.target_dir, extra_args = ['-b', smpl_branch_name])
        git_checkout(tree=args.target_dir, extra_args = ['-b', patch_branch_name])

        apply_patches(args, patch_src, args.target_dir, logwrite)

        git_checkout(tree=args.target_dir, extra_args = [smpl_branch_name])

    if os.path.isfile(glimpse_index):
        extra_spatch_args.append('--use-glimpse')
    elif modern_spatch and git_dir:
        extra_spatch_args.append('--use-gitgrep')
    else:
        extra_spatch_args.append('--use-coccigrep')

    if modern_spatch:
        output = spatch(args.cocci_file, args.target_dir, logwrite, jobs,
                        os.path.basename(args.cocci_file),
                        extra_args=extra_spatch_args)
    else:
        output = threaded_spatch(args.cocci_file,
                                 args.target_dir,
                                 logwrite,
                                 jobs,
                                 os.path.basename(args.cocci_file),
                                 extra_args=extra_spatch_args)
    if args.verbose:
        logwrite(output)
    if args.show_proof:
        git_commit_all(tree=args.target_dir, message="Initial commit")
        git_checkout(tree=args.target_dir, extra_args = [current_branch])
        cmd = [ '--stat', patch_branch_name + ".." + smpl_branch_name ]
        diff_stat = git_diff(tree=args.target_dir, extra_args = cmd)
        if len(diff_stat) == 0:
            logwrite('\nSmPL patch fully replaces patch series!\n')
        else:
            logwrite('\nDifferences found:\n\n')
            logwrite('Change directory to %s and run:\n\n\tgit diff %s..%s\n\n' % (args.target_dir, patch_branch_name, smpl_branch_name))
            logwrite('diffstat of the changes:\n')
            logwrite(diff_stat)
    return 0

if __name__ == '__main__':
    # Only a non-zero result is turned into an explicit exit status.
    exit_status = _main()
    if exit_status:
        sys.exit(exit_status)
