#!/usr/bin/python3

import re
import sys
import argparse
from pathlib import Path


# Names of the classes that are treated as abstract. main() passes this set to
# render_map() as its *abstract* argument, which lists any of them that appear
# in the chart under the "Abstract classes" node style.
ABSTRACT_CLASSES = {
    'Device',
    'GPIODevice',
    'SmoothedInputDevice',
    'AnalogInputDevice',
    'MCP3xxx',
    'MCP3xx2',
    'MCP30xx',
    'MCP32xx',
    'MCP33xx',
    'SPIDevice',
    'CompositeDevice',
    'CompositeOutputDevice',
    'LEDCollection',
    'InternalDevice',
}

# Names of classes that are always left out of the chart. main() merges this
# set with any -o/--omit options and passes the result to make_class_map() as
# its *omit* argument.
OMIT_CLASSES = {
    'object',
    'GPIOBase',
    'GPIOMeta',
    'frozendict',
    'WeakMethod',
    '_EnergenieMaster',
}


def main(args=None):
    """
    A simple application for generating GPIO Zero's charts. Specify the root
    class to generate with -i (multiple roots can be specified). Specify parts
    of the hierarchy to exclude with -x. Output is in a format suitable for
    feeding to graphviz's dot application.
    """
    # NOTE: the docstring above doubles as the argparse description, so its
    # text is part of the --help output.
    #
    # *args* is the list of command line arguments to parse; when it is None
    # the arguments are taken from sys.argv (minus the program name). The
    # function returns nothing; the rendered chart is written to the output
    # file given on the command line (stdout by default).
    if args is None:
        args = sys.argv[1:]
    my_path = Path(__file__).parent
    default_path = (my_path / '..' / '..' / 'gpiozero').resolve()
    parser = argparse.ArgumentParser(description=main.__doc__)
    # argparse applies %-formatting to help strings, so any literal "%" in
    # the default path must be doubled or rendering the help would fail
    parser.add_argument('-p', '--path', action='append', metavar='PATH',
                        default=[], help=
                        "search under PATH for Python source files; can be "
                        "specified multiple times, defaults to %s" %
                        str(default_path).replace('%', '%%'))
    parser.add_argument('-i', '--include', action='append', metavar='CLASS',
                        default=[], help=
                        "only include classes which have CLASS somewhere in "
                        "their ancestry; can be specified multiple times")
    parser.add_argument('-x', '--exclude', action='append', metavar='CLASS',
                        default=[], help=
                        "exclude any classes which have CLASS somewhere in "
                        "their ancestry; can be specified multiple times")
    parser.add_argument('-o', '--omit', action='append', metavar='CLASS',
                        default=[], help="omit the specified class, but not "
                        "its descendents from the chart; can be specified "
                        "multiple times")
    parser.add_argument('output', nargs='?', type=argparse.FileType('w'),
                        default=sys.stdout, help=
                        "the file to write the output to; defaults to stdout")
    args = parser.parse_args(args)
    if not args.path:
        args.path = [str(default_path)]

    m = make_class_map(set(args.path), OMIT_CLASSES | set(args.omit))
    if args.include or args.exclude:
        m = filter_map(m, include_roots=set(args.include), exclude_roots=set(args.exclude))
    try:
        args.output.write(render_map(m, ABSTRACT_CLASSES))
    finally:
        # Close files opened by argparse on our behalf, but never the
        # process-wide stdout stream
        if args.output is not sys.stdout:
            args.output.close()


def make_class_map(search_paths, omit):
    """
    Find all Python source files under *search_paths*, extract (via a crude
    regex) all class definitions and return a mapping of class-name to the set
    of base class names.

    *search_paths* is an iterable of directories which are searched
    recursively for ``*.py`` files. All classes listed in *omit* will be
    excluded from the result, but not their descendents (useful for excluding
    "object" etc.)

    Classes declared without any bases (``class Foo:`` or ``class Foo():``)
    map to an empty set, and keyword arguments in the base list (such as
    ``metaclass=...``) are not treated as base classes.

    The regex only recognises class statements that start at the beginning of
    a line and whose base list fits on that same line. If the same class name
    is found more than once, the last definition encountered wins.
    """
    class_re = re.compile(
        r'^class\s+(?P<name>\w+)\s*(?:\((?P<bases>.*)\))?:', re.MULTILINE)

    def parse_bases(bases):
        # *bases* is None when the class statement has no parentheses, and
        # splitting an empty string yields [''], so blank entries must be
        # dropped explicitly
        stripped = (base.strip() for base in (bases or '').split(','))
        return {
            base
            for base in stripped
            if base and '=' not in base and base not in omit
        }

    def find_classes():
        for path in search_paths:
            for py_file in Path(path).rglob('*.py'):
                # Python source defaults to UTF-8, whatever the locale says
                source = py_file.read_text(encoding='utf-8')
                for match in class_re.finditer(source):
                    name = match.group('name')
                    if name not in omit:
                        yield name, parse_bases(match.group('bases'))

    return dict(find_classes())


def filter_map(class_map, include_roots, exclude_roots):
    """
    Returns *class_map* (which is a mapping such as that returned by
    :func:`make_class_map`), with only those classes which have at least one
    of the *include_roots* in their ancestry, and none of the *exclude_roots*.

    A class counts as being in its own ancestry, so the roots themselves are
    included (or excluded). If *include_roots* is empty, every class that is
    not excluded is kept.

    Classes which only appear as bases of the selected classes are added to
    the result too, linked to each other where *class_map* records a
    relationship between them. Bases which are not keys of *class_map* (for
    example classes defined outside the searched sources) are added with no
    bases of their own.
    """
    def has_parent(cls, parent):
        return cls == parent or any(
            has_parent(base, parent) for base in class_map.get(cls, ()))

    def descends_from_any(cls, roots):
        return any(has_parent(cls, root) for root in roots)

    filtered = {
        name: bases
        for name, bases in class_map.items()
        if (not include_roots or descends_from_any(name, include_roots))
        and not descends_from_any(name, exclude_roots)
    }
    pure_bases = {
        base for bases in filtered.values() for base in bases
    } - set(filtered)
    # Make a second pass to fill in missing links between classes that are
    # only included as bases of other classes. A base may be absent from
    # class_map entirely, in which case it simply has no known bases.
    for base in pure_bases:
        filtered[base] = pure_bases.intersection(class_map.get(base, ()))
    return filtered


def render_map(class_map, abstract):
    """
    Renders *class_map* (which is a mapping such as that returned by
    :func:`make_class_map`) to graphviz's dot language, returning the result
    as a string.

    The *abstract* iterable (a set, list, tuple, etc. of class names)
    determines which classes will be rendered lighter to indicate their
    abstract nature; names in it which do not occur in *class_map* are
    ignored. All classes with names ending "Mixin" will be implicitly
    rendered in a different style.

    One edge is emitted from each class to each of its bases, ordered by
    class name and then base name.
    """
    # Every name mentioned anywhere in the map, whether as a class or a base
    names = set(class_map)
    for bases in class_map.values():
        names.update(bases)
    # Accept any iterable of names, not just sets
    abstract = set(abstract)

    template = """\
digraph classes {{
    graph [rankdir=RL];
    node [shape=rect, style=filled, fontname=Sans, fontsize=10];
    edge [];

    /* Mixin classes */
    node [color="#c69ee0", fontcolor="#000000"]

    {mixin_nodes}

    /* Abstract classes */
    node [color="#9ec6e0", fontcolor="#000000"]

    {abstract_nodes}

    /* Concrete classes */
    node [color="#2980b9", fontcolor="#ffffff"];

    {edges}
}}
"""

    # Nodes declared after a "node [...]" statement pick up that style, so
    # mixin and abstract classes are declared explicitly while everything
    # else is created implicitly by the edges under the final style
    return template.format(
        mixin_nodes='\n    '.join(
            '{name};'.format(name=name)
            for name in sorted(names)
            if name.endswith('Mixin')
        ),
        abstract_nodes='\n    '.join(
            '{name};'.format(name=name)
            for name in sorted(abstract & names)
        ),
        edges='\n    '.join(
            '{name}->{base};'.format(name=name, base=base)
            for name, bases in sorted(class_map.items())
            for base in sorted(bases)
        ),
    )


# Run the command line interface when this file is executed as a script
if __name__ == '__main__':
    main()
