#!/usr/bin/env python
"""
Generate Python source code: random function calls with random arguments.
Use "python" command line program.

Interesting modules: all modules written in C or having some code written
in C, see Modules/*.c in Python source code.
"""

# Fuzzer options
# IGNORE_TIMEOUT: if true, Fuzzer.setupProject() passes timeout_score=0 to
# the process watcher
IGNORE_TIMEOUT = True
# IGNORE_CPU: if true, Fuzzer.setupProject() sets the max_score of the CPU
# probe (when there is one) to 0
IGNORE_CPU = True
# SHOW_STDOUT: if true, the stdout watcher shows matching and not matching
# lines
SHOW_STDOUT = False
# DEBUG: if true, show stdout as SHOW_STDOUT does and generate shorter
# scripts (smaller NB_CALL, NB_METHOD and NB_CLASS)
DEBUG = False

from fusil.application import Application
from optparse import OptionGroup
from os.path import exists as path_exists, isabs, sep as path_sep
from fusil.process.watch import WatchProcess
from types import FunctionType, BuiltinFunctionType
from random import choice, randint
from fusil.bytes_generator import BytesGenerator
from fusil.unicode_generator import (
    UnicodeGenerator, IntegerGenerator, IntegerRangeGenerator,
    UnsignedGenerator, UnixPathGenerator,
    LETTERS, DECIMAL_DIGITS, UNICODE_65535, PRINTABLE_ASCII)
from fusil.process.stdout import WatchStdout
from fusil.project_agent import ProjectAgent
from fusil.process.create import CreateProcess
from fusil.write_code import WriteCode
from inspect import ismethoddescriptor, isclass
from ptrace.os_tools import RUNNING_PYTHON3
from ptrace.six import text_type, string_types
from ptrace.six.moves import range as xrange
from sys import executable, version as PYTHON_VERSION
import imp
import pkgutil
import re
import sys
import warnings
if not RUNNING_PYTHON3:
    from sys import getfilesystemencoding

# Constants
# Default value of the --timeout option (seconds)
TIMEOUT = 30.0
# Default value of the --python option: the interpreter running this script
PYTHON = executable
# Default value of the --filenames option: absolute paths separated by commas
FILENAMES = "/etc/motd,/bin/sh"
# If true, getNbArg() uses the docstring of a function to guess its number
# of arguments
PARSE_PROTOTYPE = True
# Match a prototype like "name(arguments)": group 1 is the text between
# the parentheses
PROTOTYPE_REGEX = re.compile(r"[A-Za-z]+[A-Za-z_0-9]*\(([^)]*)\)", re.MULTILINE)

# Modules which are never tested (used by ListAllModules to skip a module
# and all of its submodules)
MODULE_BLACKLIST = set((
    # wait keystroke from the keyboard
    "getpass",
    # execute child processes
    "commands",
    "subprocess",
    # run a webbrowser
    "antigravity",
    # too slow, search for .py in the whole hard drive
    "compileall",
    # nothing to fuzz
    "user",
    "this",
    # module dedicated to tests, not used in real world projects, but it's
    # built by default
    "_testcapi",
    # ignore the whole CPython test suite
    "test",
    # read/write from/to arbitrary memory (eg. ctypes.string_at())
    "ctypes", "_ctypes",
    # don't fuzz fusil nor python-ptrace
    "fusil", "ptrace",
))

CTYPES = set((
    # Read/write arbitrary memory
    'PyObj_FromPtr', 'string_at', 'wstring_at',
    'call_function', 'call_cdeclfunction',
    'Py_INCREF', 'Py_DECREF',
    'dlsym', 'dlclose',
    '_string_at_addr', '_wstring_at_addr',

    # _ctypes.dlopen("/bin/sh", -5):
    # Inconsistency detected by ld.so: dl-open.c: 652:
    # _dl_open: Assertion `_dl_debug_initialize (0, args.nsid)->r_state == RT_CONSISTENT' failed!
    # The bug is specific to Python (not reproducible outside Python)
    "dlopen",
))

SOCKET = set((
    # Avoid DNS request (timeout)
    "gethostbyname", "gethostbyname_ex", "gethostbyaddr",
    "getnameinfo", "getaddrinfo",

    # network objects having blocking operations like accept() or recv()
    "socket", "SocketType",
))

POSIX = set((
    # exit python
    "_exit", "abort",
    # read -> timeout
    "read",
    # truncate file, remove directory, remove file
    "ftruncate", "rmdir", "unlink",
    # send a signal to the current process or a process group
    "kill", "killpg",
    # create child process
    "fork", "forkpty", "system", "popen", "popen2", "popen3", "popen4",
    "spawnl", "spawnle", "spawnlp", "spawnlpe", "spawnv", "spawnve", "spawnvp", "spawnvpe",
    "execl", "execle", "execlp", "execlpe", "execv", "execve", "execvp", "execvpe",
    # wait process exit (-> timeout)
    "wait", "wait3", "waitpid",
    # break the terminal?
    "tcsetpgrp",
    # long loop
    "closerange",
))

BUILTINS = set((
    # Create huge integer, very long string or list
    "pow", "round",
))

# Functions and methods blacklist. Format:
#   module name => function and class names
# and
#    module name:class name => method names
#
# The first form is read by WritePythonCode.getFunctions(), the second one
# by WritePythonCode.getMethods() which builds the "module:class" key.
BLACKLIST = {
    # Dangerous module: ctypes
    'ctypes': CTYPES,
    '_ctypes': CTYPES,

    # Eat lot of CPU with large arguments
    'itertools': set(("tee",)),
    'math': set(("factorial",)),
    'operator': set((
        # Create huge integer, very long string or list
        "pow", "__pow__",
        "ipow", "__ipow__",
        "mul", "__mul__",
        "repeat", "__repeat__",
    )),

    # The module of builtins is called __builtin__ in Python 2 and builtins
    # in Python 3
    '__builtin__': BUILTINS,
    'builtins': BUILTINS,

    # Don't raise SystemError
    # (the Python 2 key used to be misspelled '_builtin__:set' and so was
    # never matched)
    '__builtin__:set': set(("test_c_api",)),
    'builtins:set': set(("test_c_api",)),

    # Sleep
    'time': set(("sleep",)),
    'select': set(("epoll", "poll", "select")),
    'signal': set(("pause", "alarm", "setitimer", "pthread_kill")),

    '_socket': SOCKET,
    'socket': SOCKET,
    'posix': POSIX,
    'os': POSIX,

    '_fileio:_FileIO': set((
        # timeout
        "read", "readall",
    )),

    # timeout
    'multiprocessing': set(("Pool",)),
    '_multiprocessing:SemLock': set((
        # deadlock
        "acquire",
    )),
    '_multiprocessing:Connection': set((
        # timeout
        "recv", "recv_bytes", "poll",
    )),

    '_tkinter': set((
        # timeout
        'dooneevent',
        # no window -> no event -> loop on select()
        'mainloop',
    )),

    "termios": set((
        # tcflow(0, False) suspend output to stdout
        "tcflow",
    )),

    "dl": set((
        # dl.open("/bin/sh", -5):
        # Inconsistency detected by ld.so: dl-open.c: 652:
        # _dl_open: Assertion `_dl_debug_initialize (0, args.nsid)->r_state == RT_CONSISTENT' failed!
        # The bug is specific to Python (not reproducible outside Python)
        "open",
    )),

    "pydoc": set((
        # listen to a socket and wait for requests
        "serve",
        # avoid false positive with pattern (eg. "memory") on stdout
        "doc", "apropos",
    )),

    # listen to a socket and wait for requests
    "BaseHTTPServer": set(("test",)),
    "CGIHTTPServer": set(("test",)),
    "SimpleHTTPServer": set(("test",)),

    "pprint": set(("_perfcheck",)), # timeout (unlimited loop?)
    "tabnanny": set(("check",)),    # python 2.5.2 implementation is just slow

    # create child process
    "popen2": set(("popen2", "popen3", "popen4", "Popen3", "Popen4")),
    "pty": set(("fork", "spawn")),
    "platform": set(("_syscmd_uname",)),

    # avoid false positive with pattern on stdout
    "logging": set(("warning", "error", "fatal", "critical")),
    "formatter": set(("test",)),

    # Create huge integer, very long string or list
    "fpformat": set(("fix",)),

    # remove directory
    "shutil": set(("copytree", "rmtree")),

    # open a network connection (timeout)
    # FIXME: only blacklist the blocking methods, not the whole class?
    "imaplib": set(("IMAP4", "IMAP4_stream",)),
    "telnetlib": set(("Telnet",)),
    "nntplib": set(("NNTP",)),
    "smtplib": set(("SMTP", "SMTP_SSL")),

    # open a network connection (timeout),
    # the constructor opens directly a connection
    "poplib": set(("POP3", "POP3_SSL",)),
    "ftplib": set(("FTP", "FTP_TLS")),

    # set resource limit, may stop the process:
    # setrlimit(RLIMIT_CPU, (0, 0)) kills the process with a SIGKILL signal
    "resource": set(("setrlimit",)),

    "xmllib": set(("test",)),          # timeout
    "urllib2": set(("randombytes",)),  # unlimited loop

    "py_compile": set(("compile",)),
    "runpy": set(("run_path",)),

    "faulthandler": set((
        '_fatal_error', '_read_null', '_sigabrt', '_sigbus', '_sigfpe', '_sigill', '_sigsegv', '_stack_overflow',
    )),

    # TODO: blacklist distutils/spawn.py (35): spawn
    # TODO: blacklist distutils/spawn.py (121): _spawn_posix
}

# Size of the generated script:
#  - NB_CALL: number of function calls written for a module
#  - NB_METHOD: number of method calls written for each created object
#  - NB_CLASS: number of objects created
# Debug mode generates a much shorter script.
NB_CALL, NB_METHOD, NB_CLASS = (5, 1, 5) if DEBUG else (250, 15, 50)

# Maximum number of arguments of a call when the prototype is unknown
MAX_ARG = 6
# Maximum number of values generated for a variable argument (eg. *args)
MAX_VAR_ARG = 5

# Source code literals picked by genSurrogates() and genBufferObject():
# each item is the text of a Python expression, written as-is in the
# generated script. The spelling depends on the Python major version
# running the fuzzer (u"..." prefix and buffer() in Python 2).
if RUNNING_PYTHON3:
    SURROGATES = (
        u'"\\uDC80"',
        # NOTE(review): same literal as the previous item -- confirm whether
        # another code point was intended
        u'"\\uDC80"',
        u'"\\U0010FFFF"',
        u'"\\udbff\\udfff"',
    )

    BUFFER_OBJECTS = (
        u'bytearray(b"abc\\xe9\\xff")',
        u'memoryview(b"abc\\xe9\\xff")',
        u'memoryview(bytearray(b"abc\\xe9\\xff"))',
    )
else:
    SURROGATES = (
        u'u"\\uDC80"',
        # NOTE(review): same literal as the previous item -- confirm whether
        # another code point was intended
        u'u"\\uDC80"',
        u'u"\\U0010FFFF"',
        u'u"\\udbff\\udfff"',
    )

    BUFFER_OBJECTS = (
        u'buffer("abc\\xe9\\xff")',
    )

class ListAllModules:
    """
    Collect the names of the available modules: the built-in modules
    (sys.builtin_module_names, except __main__) plus the modules found by
    pkgutil.iter_modules(), including the submodules of packages.

    Arguments:
     - logger: object with a warning(message) method
     - only_c: if true, only keep modules which are not packages and whose
       filename is not a .py, .pyc or .pyo file
     - site_package: if false, ignore "site-packages" directories
     - blacklist: names of modules to skip with all of their submodules;
       a name is compared to the last component of the module name
    """
    def __init__(self, logger, only_c, site_package, blacklist):
        self.logger = logger
        self.only_c = only_c
        self.site_package = site_package
        self.blacklist = blacklist
        self.names = set(sys.builtin_module_names) - set(('__main__',))

    def matchModule(self, is_package, name, filename):
        """
        Return True if the module has to be tested.

        filename may be None (module without any file): such module is
        never rejected because of its filename.
        """
        if filename is not None and not self.filter_filename(filename):
            return False
        if not self.only_c:
            return True
        if is_package:
            return False
        if filename is None:
            # No file, so it cannot be a Python source/bytecode file
            # (it used to raise an AttributeError on filename.endswith)
            return True
        return not filename.endswith(('.py', '.pyc', '.pyo'))

    def filter_filename(self, filename):
        """
        Return False if filename is in a site-packages directory and
        site-packages are ignored, True otherwise.
        """
        if self.site_package:
            return True
        return 'site-packages' not in filename.split(path_sep)

    def search_modules(self, path, prefix):
        """
        Add to self.names the modules found in the directories of path
        (None: default search path of pkgutil.iter_modules()), and search
        recursively in packages. prefix is the name of the parent package
        followed by a dot, or an empty string.
        """
        if path is not None and not self.site_package:
            path = [name for name in path
                    if self.filter_filename(name)]
            if not path:
                return
        for loader, name, is_package in pkgutil.iter_modules(path, ''):
            if name.endswith('_d'):
                # Ignore Debian debug modules
                # (eg. ignore "_bisect_d", the real module is "_bisect")
                continue
            if name in self.blacklist:
                # Ignore the module and all of its submodules
                continue
            fullname = prefix + name
            # Get module filename / directory name
            try:
                handle, filename, description = imp.find_module(name, path)
            except ImportError:
                self.logger.warning("Failed to import module %s" % name)
                continue
            if handle:
                # find_module() opens the module file: only its name is used
                handle.close()
            if self.matchModule(is_package, name, filename):
                self.names.add(fullname)
            if is_package:
                self.search_modules([filename], fullname + '.')

    def search(self):
        """
        Search all modules and return the set of module names.
        """
        self.search_modules(None, '')
        return self.names

class Fuzzer(Application):
    """
    Fusil application: generate a Python script calling random functions
    of a module with random arguments, run it and watch the process.
    """
    NAME = "python"

    def createFuzzerOptions(self, parser):
        """
        Create the "Python fuzzer" group of command line options for parser
        and return it.
        """
        options = OptionGroup(parser, "Python fuzzer")
        options.add_option("--modules",
            help="Tested Python module names separated by commas (default: test all modules)",
            type="str", default="*")
        options.add_option("--blacklist",
            help='Module blacklist separated by commas (eg. "_lsprof,_json")',
            type="str")
        options.add_option("--test-private",
            help="Test private methods (default: skip privates methods)",
            action="store_true")
        options.add_option("--timeout",
            help="Timeout in seconds (default: %.1f)" % TIMEOUT,
            type="float", default=TIMEOUT)
        options.add_option("--filenames",
            help="Names separated by commas of readable files (default: %s)" % FILENAMES,
            type="str", default=FILENAMES)
        options.add_option("--python",
            help="Python executable program path (default: %s)" % PYTHON,
            type="str", default=PYTHON)
        options.add_option("--only-c",
            help="Only search for modules written in C (default: search all module)",
            action="store_true")
        options.add_option("--no-site-packages",
            help="Don't search modules in site-packages directory",
            action="store_true")
        return options

    def setupProject(self):
        """
        Create the agents of the project: the source code writer, the Python
        process and the probes watching the process and its output.
        """
        project = self.project

        project.error("Use python interpreter: %s" % self.options.python)
        # PYTHON_VERSION is sys.version of the interpreter running the
        # fuzzer, which may differ from the --python program
        version = ' -- '.join( line.strip() for line in PYTHON_VERSION.splitlines())
        project.error("Python version: %s" % version)
        PythonSource(project, self.options)
        # '<source.py>' is a placeholder: PythonProcess.on_python_source()
        # replaces the last argument by the name of the generated script
        process = PythonProcess(project, [self.options.python, '-u', '<source.py>'], timeout=self.options.timeout)
        # Scores set to 0; presumably a zero score means the event is not
        # counted as a failure -- confirm against WatchProcess
        options = {'exitcode_score' : 0}
        if IGNORE_TIMEOUT:
            options['timeout_score'] = 0
        watch = WatchProcess(process, **options)
        if watch.cpu and IGNORE_CPU:
            watch.cpu.max_score = 0

        stdout = WatchStdout(process)
        stdout.max_nb_line = None

        # Disable dummy error messages
        # (replace the word => score mapping of the probe by this one)
        stdout.words = {
            'oops': 0.30,
            'bug': 0.30,
            'fatal': 1.0,
            'assert': 1.0,
            'assertion': 1.0,
            'critical': 1.0,
            'panic': 1.0,
            'glibc detected': 1.0,
            'segfault': 1.0,
            'segmentation fault': 1.0,
        }

        # CPython critical messages
        stdout.addRegex("^XXX undetected error", 1.0)
        stdout.addRegex("Fatal Python error", 1.0)
        # Match "Cannot allocate memory"?

        # PyPy messages
        stdout.addRegex("Fatal RPython error", 1.0)

        if SHOW_STDOUT or DEBUG:
            stdout.show_matching = True
            stdout.show_not_matching = True

        # avoid matching on "assert" keyword
        # NOTE(review): "()" is an empty regex group, not literal
        # parentheses -- confirm that matching "ast.Assert" alone is intended
        stdout.ignoreRegex(r"ast\.Assert()")

        # PyPy interact prompt
        # avoid false positive on "# assert did not crash"
        stdout.ignoreRegex(r"^And now for something completely different:")

        # Hide Python warnings on import
        warnings.simplefilter('ignore')

# Name of the function defined at the top of the generated script which
# always raises a ValueError; genErrback() uses it as an argument value
ERRBACK_NAME = u'errback'

# Known number of arguments of special methods, read by getNbArg():
#   method name => exact number of arguments
# or
#   method name => (minimum, maximum) number of arguments
METHODS_NB_ARG = {
    '__str__': 0,
    '__repr__': 0,
    '__hash__': 0,
    '__reduce__': 0,
    '__delattr__': 1,
    '__getattribute__': 1,
    '__getitem__': 1,
    '__getslice__': 2,
    '__reduce_ex__': (0, 1),
    '__getstate__': 0,
    '__setattr__': 2,
    '__setstate__': 1,
}


class PythonFuzzerError(Exception):
    """
    Error of the Python fuzzer, eg. a module without any function nor class.
    """

class PythonSource(ProjectAgent):
    def __init__(self, project, options):
        """
        Compute the sorted list of tested modules (self.modules) and the
        list of existing filenames (self.filenames) from the command line
        options.

        Arguments:
         - project: fusil project
         - options: parsed command line options (modules, blacklist,
           only_c, no_site_packages and filenames are read here)

        Raise a ValueError if a name of --filenames is not an absolute path
        or is not the name of an existing file.
        """
        ProjectAgent.__init__(self, project, "python_source")
        self.options = options

        def split_names(text):
            # "a, b,,c" => set(["a", "b", "c"]): strip spaces and ignore
            # empty names
            names = (name.strip() for name in text.split(","))
            return set(name for name in names if name)

        if self.options.modules != "*":
            self.modules = split_names(self.options.modules)
        else:
            self.error("Search all Python modules...")
            self.modules = ListAllModules(self, self.options.only_c, not self.options.no_site_packages, MODULE_BLACKLIST).search()
        blacklist = self.options.blacklist
        if blacklist:
            # Parse --blacklist as --modules: "a, b" has to remove "b",
            # not " b"
            blacklist = split_names(blacklist)
            removed = self.modules & blacklist
            self.error("Blacklist modules: %s" % removed)
            self.modules = self.modules - blacklist
        self.modules = sorted(self.modules)
        self.error("Found %s Python modules" % len(self.modules))
        for name in self.modules:
            self.info("Python module: %s" % name)

        self.filenames = self.options.filenames
        if not RUNNING_PYTHON3:
            # Python 2: decode the byte string from the command line
            encoding = getfilesystemencoding()
            self.filenames = text_type(self.filenames, encoding)
        self.filenames = self.filenames.split(u",")
        for filename in self.filenames:
            if not isabs(filename):
                raise ValueError("Filename %r is not an absolute path! Fix the --filenames option" % filename)
            if not path_exists(filename):
                raise ValueError("File doesn't exist: %s! Use different --filenames option" % filename)
        project.error(u"Use filenames: %s" % u', '.join(self.filenames))

    def loadModule(self, module_name):
        """
        Import the module called module_name, store it in self.module and
        create the code writer of this module (self.write).

        Errors of the import and of WritePythonCode are not caught.
        """
        self.module_name = module_name
        self.debug("Import %s" % self.module_name)

        # __import__("a.b.c") returns the package "a": walk down to "c"
        components = self.module_name.split(".")
        self.module = __import__(self.module_name)
        for component in components[1:]:
            self.module = getattr(self.module, component)

        # Some modules (eg. built-in modules) have no __file__ attribute
        try:
            module_file = self.module.__file__
        except AttributeError:
            module_file = None
            has_file = False
        else:
            has_file = True
        if has_file:
            self.warning("Module filename: %s" % module_file)

        self.write = WritePythonCode(self, self.filename, self.module, self.module_name)

    def on_session_start(self):
        """
        Choose a random module, write the script testing this module and
        send its filename with the 'python_source' event.

        A module which cannot be loaded is removed from self.modules and
        another one is tried. If no module is left, send 'project_stop'.

        Modules imported to generate the script are always removed from
        sys.modules at exit, even if there is no more module or if writing
        the script fails.
        """
        self.filename = self.session().createFilename(u'source.py')

        # copy sys.modules
        old_sys_modules = sys.modules.copy()
        try:
            while self.modules:
                name = choice(self.modules)
                try:
                    self.loadModule(name)
                    break
                except BaseException as err:
                    # Importing an arbitrary module may raise anything,
                    # including SystemExit
                    self.error("Unable to load module %s: [%s] %s"
                        % (name, err.__class__.__name__, err))
                    self.modules.remove(name)
            if not self.modules:
                self.error("There is no more modules!")
                self.send('project_stop')
                return
            self.error("Test module %s" % name)
            self.send('session_rename', name)

            self.write.writeSource()
            self.send('python_source', self.filename)
        finally:
            # unload new modules
            sys.modules.clear()
            sys.modules.update(old_sys_modules)

class PythonProcess(CreateProcess):
    """
    Process running the generated Python script.
    """
    def on_python_source(self, filename):
        """
        Called with the filename sent as 'python_source': use it as the
        last argument of the command line, then create the process.
        """
        arguments = self.cmdline.arguments
        arguments[-1] = filename
        self.createProcess()

# >'<, >"<, >\<
ESCAPE_CHARARACTERS = u"'" + u'"' + u'\\'

def formatCharacter(char):
    """
    Format a character to write it in an ASCII Python string literal:
     - quote, double quote and backslash are prefixed by a backslash
     - printable ASCII (32..126) is kept unchanged
     - other characters are written as \\xXX, \\uXXXX or \\UXXXXXXXX
       depending on their code point
    """
    if char in ESCAPE_CHARARACTERS:
        # >\"<
        return u'\\' + char
    code = ord(char)
    if code > 65535:
        # >\U00010FA3<
        return u'\\U%08X' % code
    if code > 255:
        # >\u0101<
        return u'\\u%04X' % code
    if code < 32 or code > 126:
        # >\xEF<
        return u'\\x%02X' % code
    # >a<
    return char

def escapeUnicode(text):
    """
    Escape each character of text with formatCharacter() and return the
    concatenation of the results.
    """
    escaped = map(formatCharacter, text)
    return ''.join(escaped)

class WritePythonCode(WriteCode):
    def __init__(self, parent, filename, module, module_name):
        """
        Prepare the writer of the script testing a module.

        Arguments:
         - parent: object with filenames and options attributes
         - filename: name of the script to write
         - module: imported module object
         - module_name: name of the module

        Raise a PythonFuzzerError if the module has no function and no
        class to test.
        """
        WriteCode.__init__(self)

        # Tested module and output file
        self.module = module
        self.module_name = module_name
        self.filename = filename
        self.filenames = parent.filenames
        self.options = parent.options

        # Random value generators
        self.smallint_generator = IntegerRangeGenerator(-19, 19)
        self.int_generator = IntegerGenerator(20)
        self.bytes_generator = BytesGenerator(0, 20)
        self.unicode_generator = UnicodeGenerator(1, 20, UNICODE_65535)
        self.ascii_generator = UnicodeGenerator(0, 20, PRINTABLE_ASCII)
        self.unix_path_generator = UnixPathGenerator(100)
        self.letters_generator = UnicodeGenerator(1, 8, LETTERS | DECIMAL_DIGITS)
        self.float_int_generator = IntegerGenerator(3)
        self.float_float_generator = UnsignedGenerator(3)

        # Argument generators: the hashable ones are also used for the keys
        # of a dictionary (genOpenFile and genException are not enabled)
        hashable = (
            self.genNone,
            self.genBool,
            self.genSmallUint,
            self.genInt,
            self.genLetterDigit,
            self.genBytes,
            self.genString,
            self.genSurrogates,
            self.genAsciiString,
            self.genUnixPath,
            self.genFloat,
            self.genExistingFilename,
            self.genErrback,
        )
        self.hashable_argument_generators = hashable
        self.simple_argument_generators = hashable + (self.genBufferObject,)
        self.complex_argument_generators = (
            self.genList,
            self.genTuple,
            self.genDict,
        )

        functions, classes = self.getFunctions()
        self.functions = functions
        self.classes = classes
        if not (functions or classes):
            raise PythonFuzzerError("Module %s has no function and no class!" % self.module_name)

    def writePrint(self, level, arguments):
        """
        Write an instruction printing arguments (source code of the printed
        expression) to stderr, using the print syntax of the Python major
        version running the fuzzer.
        """
        template = u"print (%s, file=stderr)" if RUNNING_PYTHON3 else u"print >>stderr, %s"
        self.write(level, template % arguments)

    def writeSource(self):
        """
        Write the whole script in self.filename: imports, the errback
        function, the callMethod() and callFunc() helpers, then the random
        calls to the functions and classes of the module.
        """
        self.createFile(self.filename)
        # Script header: the generated code only contains ASCII characters
        self.write(0, u"# -*- coding: ASCII -*-")
        self.write(0, u"from gc import collect")
        self.write(0, u"from sys import stderr")
        # Announce the import on stderr before doing it
        self.writePrint(0, u'"import %s"' % self.module_name)
        self.write(0, u"import %s" % self.module_name)
        self.emptyLine()
        # Callable argument which always fails (see genErrback())
        self.write(0, u"def %s(*args, **kw):" % ERRBACK_NAME)
        self.write(1, u"raise ValueError('error')")
        self.emptyLine()

        self.write(0, "def callMethod(prefix, object, name, *arguments):")
        # The body of callMethod() is written one level deeper
        level = self.addLevel(1)
        self.writeCallMethod()
        self.restoreLevel(level)
        self.emptyLine()

        # callFunc() calls an attribute of the tested module
        self.write(0, "def callFunc(prefix, name, *arguments):")
        self.write(1, "return callMethod(prefix, %s, name, *arguments)" % self.module_name)
        self.emptyLine()

        self.writeCode(u'', self.module, self.functions, self.classes, 1, NB_CALL)
        self.close()

    def writeCallMethod(self):
        """
        Write the body of the callMethod(prefix, object, name, *arguments)
        function of the script: print the name of the called function, call
        it, print the exception if the call fails, run the garbage collector
        and return the result (None on error).
        """
        # %%s and %% are kept as %s and % in the generated code
        self.write(0, u'funcname = "%s.%%s()" %% name' % self.module_name)
        self.write(0, u'message = "[%s] %s" % (prefix, funcname)')
        self.writePrint(0, u'message')
        self.write(0, u'func = getattr(object, name)')
        self.write(0, u'try:')
        self.write(1, u'result = func(*arguments)')

        # SystemExit and KeyboardInterrupt are not subclasses of Exception
        exceptions = u'(Exception, SystemExit, KeyboardInterrupt)'
        # The except syntax depends on the Python major version running
        # the fuzzer
        if RUNNING_PYTHON3:
            self.write(0, u'except %s as err:' % exceptions)
        else:
            self.write(0, u'except %s, err:' % exceptions)
        self.write(1, u'errmsg = repr(err)')
        if RUNNING_PYTHON3:
            self.write(1, u"errmsg = errmsg.encode('ASCII', 'replace')")
        self.writePrint(1, u'"[%s] %s => %s: %s" % (prefix, funcname, err.__class__.__name__, errmsg)')
        self.write(1, u'result = None')

        self.writePrint(0, u'"[%s] -garbage collector-" % prefix')
        self.write(0, u'collect()   # explicit call to the garbage collector')
        self.write(0, u'return result')

    def getFunctions(self):
        """
        List the functions and the classes of the module.

        Names come from __all__ if the module has one, from dir() otherwise.
        Names blacklisted for this module are ignored.

        Return (functions, classes): two lists of names.
        """
        module = self.module
        blacklist = BLACKLIST.get(self.module_name)
        if blacklist is None:
            blacklist = set()

        if hasattr(module, "__all__"):
            names = set(module.__all__)
        else:
            names = set(dir(module))
        names -= set(("__builtins__", "__doc__", "__file__", "__name__"))
        names -= blacklist

        missing = object()
        functions = []
        classes = []
        for name in names:
            # attribute declared in __all__, but not defined?
            attr = getattr(module, name, missing)
            if attr is missing:
                continue
            if isinstance(attr, (FunctionType, BuiltinFunctionType)):
                functions.append(name)
            elif isinstance(attr, type) or isclass(attr):
                classes.append(name)
        return functions, classes

    def getMethods(self, object, class_name):
        """
        List the names of the method descriptors of a class.

        Arguments:
         - object: the class
         - class_name: name of the class, used to get the blacklist of its
           methods ("module name:class name" key of BLACKLIST)

        Names starting with "__" are skipped unless the --test-private
        option is used. Attributes listed by dir() which cannot be read are
        skipped.
        """
        key = "%s:%s" % (self.module_name, class_name)
        blacklist = BLACKLIST.get(key, ())
        test_private = self.options.test_private
        methods = []
        for name in dir(object):
            if name in blacklist:
                continue
            if (not test_private) and name.startswith("__"):
                continue
            try:
                attr = getattr(object, name)
            except AttributeError:
                # dir() may list a name which cannot be read, eg.
                # type.__abstractmethods__: ignore it as getFunctions() does
                continue
            if ismethoddescriptor(attr):
                methods.append(name)
        return methods

    def _createArgument(self, generators):
        """
        Call a random generator of generators and return its result: the
        source code of a value as a list of lines.

        Raise a ValueError if a line is not a character string.
        """
        generator = choice(generators)
        lines = generator()
        for line in lines:
            if isinstance(line, text_type):
                continue
            raise ValueError("%s returned type %s" % (generator, type(line)))
        return lines

    def createArgument(self):
        """
        Create the source code (list of lines) of a random simple value.
        """
        generators = self.simple_argument_generators
        return self._createArgument(generators)

    def createHashableArgument(self):
        """
        Create the source code (list of lines) of a random value picked
        from the hashable generators (used for dictionary keys).
        """
        generators = self.hashable_argument_generators
        return self._createArgument(generators)

    def createComplexArgument(self):
        """
        Create the source code (list of lines) of a random value: a
        container (list, tuple, dict) 10% of the time, a simple value
        otherwise.
        """
        use_container = (randint(0, 9) == 0)
        if use_container:
            chosen = self.complex_argument_generators
        else:
            chosen = self.simple_argument_generators
        return self._createArgument(chosen)

    def getNbArg(self, func, func_name, min_arg):
        """
        Get the (minimum, maximum) number of arguments of a function.

        Sources, by priority: the METHODS_NB_ARG table, the prototype
        written in the docstring of func (if PARSE_PROTOTYPE is set), then
        the default (min_arg, MAX_ARG).
        """
        # Known method of arguments?
        if func_name in METHODS_NB_ARG:
            known = METHODS_NB_ARG[func_name]
            if isinstance(known, tuple):
                lowest, highest = known
            else:
                lowest = highest = known
            return lowest, highest

        # Try using the documentation
        if PARSE_PROTOTYPE:
            bounds = parseDocumentation(func.__doc__, MAX_VAR_ARG)
            if bounds:
                return bounds

        return min_arg, MAX_ARG

    def callFunction(self, prefix, func_index, func_name, func, min_arg):
        """
        Write a call to a function with a random number of random arguments.

        With a prefix, the call is a method call on "obj" (callMethod),
        otherwise it is a call to a function of the module (callFunc).
        The call number (func_index + 1) is part of the printed label.
        """
        lowest, highest = self.getNbArg(func, func_name, min_arg)
        count = randint(lowest, highest)

        number = 1 + func_index
        if prefix:
            label = prefix + str(number)
            opening = u'callMethod("%s", obj, "%s"' % (label, func_name)
        else:
            label = "f%s" % number
            opening = u'callFunc("%s", "%s"' % (label, func_name)

        if not count:
            self.write(0, opening + ')')
        else:
            self.write(0, opening + u',')
            saved_level = self.addLevel(1)
            for position in xrange(1, count + 1):
                # The last argument closes the call
                if position == count:
                    terminator = u')'
                else:
                    terminator = u','
                self.writeArgument(1, terminator)
            self.restoreLevel(saved_level)
        self.emptyLine()

    def writeArgument(self, level, last_char=u','):
        """
        Write the source code of a random argument at the given level,
        last_char being appended to its last line.
        """
        argument = self.createComplexArgument()
        final_line = argument[-1] + last_char
        for text in argument[:-1]:
            self.write(level, text)
        self.write(level, final_line)

    def useClass(self, cls_index, cls, class_name):
        """
        Write the code creating an object of a class with random arguments,
        then calling random methods of this object.

        Arguments:
         - cls_index: index of the object (its number is cls_index + 1)
         - cls: the class
         - class_name: name of the class in the module
        """
        # The constructor gets 1 to MAX_ARG arguments
        nb_arg = randint(1, MAX_ARG)

        prefix = 'o%s' % (1 + cls_index)
        self.writePrint(0, u'"[%s] Create object %s"' % (prefix, 1 + cls_index))

        # obj = callFunc("oN", "ClassName", <arguments>)
        obj_name = u'obj'
        self.write(0, u'%s = callFunc("%s", "%s",' % (obj_name, prefix, class_name))

        for index in xrange(nb_arg):
            self.write(2, u'# argument %s/%s' % (1+index, nb_arg))
            # each argument ends with a comma, the last one included
            self.writeArgument(2)
        self.write(1, u')')

        methods = self.getMethods(cls, class_name)
        if methods:
            # Generated code only calls methods if obj is true: callFunc()
            # returns None if the constructor failed
            self.write(0, u'if %s:' % obj_name)
            level = self.addLevel(1)
            self.writeCode(prefix+'m', cls, methods, tuple(), 0, NB_METHOD)
            self.write(0, u'del %s' % obj_name)
            self.writePrint(0, u'"[%s] -garbage collector -"' % prefix)
            self.write(0, u'collect()   # explicit call to the garbage collector')
            self.restoreLevel(level)
        self.emptyLine()

    def writeCode(self, prefix, object, functions, classes, func_min_arg, nb_call):
        """
        Write nb_call calls to random functions of object, then the creation
        and use of NB_CLASS objects of random classes of object.

        Arguments:
         - prefix: prefix of the labels of the calls
         - object: module or class owning the functions and classes
         - functions, classes: names of attributes of object
         - func_min_arg: default minimum number of arguments of a call
        """
        if functions:
            for call_index in xrange(nb_call):
                chosen = choice(functions)
                target = getattr(object, chosen)
                self.callFunction(prefix, call_index, chosen, target, func_min_arg)
        if classes:
            self.nb_class = NB_CLASS
            for class_index in xrange(self.nb_class):
                chosen = choice(classes)
                target = getattr(object, chosen)
                self.useClass(class_index, target, chosen)

    def genNone(self):
        """Source code of None."""
        lines = [u'None']
        return lines

    def genBool(self):
        """Source code of a random boolean: True or False."""
        # randint(0, 1) is 0 or 1: use it as the index of the literal
        literals = (u'False', u'True')
        return [literals[randint(0, 1)]]

    def genSmallUint(self):
        """
        Source code of a value created by the small integer generator
        (IntegerRangeGenerator(-19, 19)).
        """
        value = self.smallint_generator.createValue()
        return [value]

    def genInt(self):
        """
        Source code of a value created by the integer generator
        (IntegerGenerator(20)).
        """
        value = self.int_generator.createValue()
        return [value]

    def genBytes(self):
        """
        Source code of a random byte string: each byte is written as \\xXX,
        with the b"..." syntax in Python 3 and "..." in Python 2.
        """
        raw = self.bytes_generator.createValue()
        if RUNNING_PYTHON3:
            # Items of a bytes object are integers
            escaped = ''.join([u"\\x%02X" % byte for byte in raw])
            literal = 'b"%s"' % escaped
        else:
            # Items of a byte string are strings of one character
            escaped = u''.join([u"\\x%02X" % ord(byte) for byte in raw])
            literal = u'"%s"' % escaped
        return [literal]

    def genUnixPath(self):
        """
        Source code of a string containing a random UNIX path. The path is
        written between double quotes without any escaping.
        """
        literal = u'"%s"' % self.unix_path_generator.createValue()
        return [literal]

    def _genUnicode(self, generator):
        """
        Source code of a character string created by generator: the text is
        escaped with escapeUnicode() and written as "..." in Python 3 or
        u"..." in Python 2.
        """
        escaped = escapeUnicode(generator.createValue())
        template = u'"%s"' if RUNNING_PYTHON3 else u'u"%s"'
        return [template % escaped]

    def genLetterDigit(self):
        """Source code of a short string of letters and decimal digits."""
        generator = self.letters_generator
        return self._genUnicode(generator)

    def genString(self):
        """Source code of a string created by the Unicode generator."""
        generator = self.unicode_generator
        return self._genUnicode(generator)

    def genSurrogates(self):
        """Source code of a string literal picked from SURROGATES."""
        return [choice(SURROGATES)]

    def genBufferObject(self):
        """Source code of a buffer object picked from BUFFER_OBJECTS."""
        return [choice(BUFFER_OBJECTS)]

    def genAsciiString(self):
        """Source code of a string of printable ASCII characters."""
        generator = self.ascii_generator
        return self._genUnicode(generator)

    def genFloat(self):
        """
        Source code of a floating point number: "<integer>.<unsigned>",
        both parts coming from their own generator.
        """
        # The integer part is generated before the fractional part
        parts = (
            self.float_int_generator.createValue(),
            self.float_float_generator.createValue(),
        )
        return [u"%s.%s" % parts]

    def genExistingFilename(self):
        """
        Source code of a string containing one of the --filenames names.
        The name is written between single quotes without any escaping.
        """
        literal = u"'%s'" % choice(self.filenames)
        return [literal]

    def genErrback(self):
        """
        Source code of a reference to the errback function of the script,
        a callable which always raises a ValueError.
        """
        reference = "%s" % ERRBACK_NAME
        return [reference]

    def genOpenFile(self):
        """
        Source code of an open() call on one of the --filenames names.
        This generator is not enabled in __init__().
        """
        filename = choice(self.filenames)
        # Python 2: pass the filename as a character string (u'...')
        template = "open('%s')" if RUNNING_PYTHON3 else "open(u'%s')"
        return [template % filename]

    def genException(self):
        """
        Source code of a new Exception object.
        This generator is not enabled in __init__().
        """
        # u'' prefix: _createArgument() rejects lines which are not
        # character strings, and a plain literal is a byte string in Python 2
        return [u"Exception('pouet')"]

    def _genList(self, open_text, close_text, empty, is_dict=False):
        """
        Source code (list of lines) of a container of 0 to 9 random items.

        Arguments:
         - open_text, close_text: text written before and after the items
         - empty: source code of the empty container
         - is_dict: if true, items are "key: value" pairs

        90% of the time, all items are created by the same generator (one
        for the keys and one for the values of a dictionary).
        """
        # Order of the random calls: types, size, generators, then items
        mixed_types = (randint(0, 9) == 0)
        count = randint(0, 9)
        if count == 0:
            return [empty]

        if mixed_types:
            if is_dict:
                make_item = self.createDictItem
            else:
                make_item = self.createArgument
        else:
            if is_dict:
                key_generator = choice(self.hashable_argument_generators)
            value_generator = choice(self.simple_argument_generators)
            if is_dict:
                def make_item():
                    return self.createDictItem(key_generator, value_generator)
            else:
                make_item = value_generator
        items = [make_item() for _ in xrange(count)]

        # Join the items: a comma ends the previous item and all lines of
        # the next items start with a space
        lines = list(items[0])
        for extra in items[1:]:
            lines[-1] += u","
            lines.extend(u' ' + text for text in extra)

        # A tuple of one item needs a trailing comma: (1,)
        if count == 1 and empty == u'tuple()':
            lines[-1] += u','
        lines[0] = open_text + lines[0]
        lines[-1] += close_text
        return lines

    def createDictItem(self, key_callback=None, value_callback=None):
        """
        Source code (list of lines) of a "key: value" dictionary item.

        The key and the value are created by key_callback and
        value_callback if they are set, or by a random hashable generator
        and a random simple generator otherwise.
        """
        # The key is always generated before the value
        key = key_callback() if key_callback else self.createHashableArgument()
        value = value_callback() if value_callback else self.createArgument()
        # The value starts on the last line of the key
        joined = key[-1] + u": " + value[0]
        return key[:-1] + [joined] + value[1:]

    def genList(self):
        """Source code of a random list."""
        return self._genList(open_text=u'[', close_text=u']', empty=u'[]')

    def genTuple(self):
        """Source code of a random tuple."""
        return self._genList(open_text=u'(', close_text=u')', empty=u'tuple()')

    def genDict(self):
        """Source code of a random dictionary."""
        return self._genList(u'{', u'}', u'{}', is_dict=True)

def parseArguments(arguments, defaults):
    """
    Generate the names of the arguments of a prototype.

    Arguments:
     - arguments: text of the arguments separated by commas; spaces,
       newlines and square brackets around a name are ignored
     - defaults: dictionary filled with name => default value (as text) for
       each "name=value" argument

    Spaces around the equal sign are ignored.

    >>> defaults = {}
    >>> list(parseArguments("obj, file[, protocol = 0]", defaults))
    ['obj', 'file', 'protocol']
    >>> defaults
    {'protocol': '0'}
    """
    for arg in arguments.split(","):
        arg = arg.strip(" \n[]")
        if not arg:
            continue
        if "=" in arg:
            arg, value = arg.split("=", 1)
            # "name = value": don't keep the spaces in the name nor in
            # the value
            arg = arg.strip()
            value = value.strip()
            defaults[arg] = value
        yield arg

def parsePrototype(doc):
    r"""
    Parse the prototype written at the beginning of a docstring.

    Return None if doc is not a string starting with a prototype, or a
    tuple (arguments, vararg, varkw, defaults):
     - arguments: names of the mandatory arguments
     - vararg: name of the variable argument (eg. '*args') or None
     - varkw: names of the optional arguments (written in square brackets
       or having a default value)
     - defaults: name => default value (as text)

    >>> parsePrototype("test([x])")
    ((), None, ('x',), {})
    >>> parsePrototype('dump(obj, file, protocol=0)')
    (('obj', 'file'), None, ('protocol',), {'protocol': '0'})
    >>> parsePrototype('setitimer(which, seconds[, interval])')
    (('which', 'seconds'), None, ('interval',), {})
    >>> parsePrototype("decompress(string[, wbits[, bufsize]])")
    (('string',), None, ('wbits', 'bufsize'), {})
    >>> parsePrototype("decompress(string,\nwbits)")
    (('string', 'wbits'), None, (), {})
    >>> parsePrototype("get_referents(*objs)")
    ((), '*objs', (), {})
    >>> parsePrototype("call(func, *args, **kw)")
    (('func',), '*args', (), {})
    >>> parsePrototype("nothing")
    """
    if not doc:
        return None
    if not isinstance(doc, string_types):
        return None
    doc = doc.strip()
    match = PROTOTYPE_REGEX.match(doc)
    if not match:
        return None
    arguments = match.group(1)
    if arguments == '...':
        return None
    defaults = {}
    vararg = None
    varkw = tuple()
    if "[" in arguments:
        # "a, b[, c[, d]]": everything after the first bracket is optional
        arguments, varkw = arguments.split("[", 1)
        arguments = tuple(parseArguments(arguments, defaults))
        varkw = tuple(parseArguments(varkw, defaults))
    else:
        arguments = tuple(parseArguments(arguments, defaults))

    # Remove all trailing star arguments ("*args, **kw"), not only the last
    # one, and keep the first one as the variable argument. It is done
    # before moving arguments with a default value, so "b=1, *args" doesn't
    # count b as a mandatory argument.
    while arguments and arguments[-1].startswith("*"):
        vararg = arguments[-1]
        arguments = arguments[:-1]

    # Trailing arguments with a default value? => varkw
    split_at = len(arguments)
    while split_at and arguments[split_at - 1] in defaults:
        split_at -= 1
    varkw = arguments[split_at:] + varkw
    arguments = arguments[:split_at]

    return (arguments, vararg, varkw, defaults)

def parseDocumentation(doc, max_var_arg):
    """
    Compute the number of arguments of a function from its documentation.

    Arguments:
     - doc: documentation string
     - max_var_arg: maximum number of arguments for variable argument,
       eg. test(*args).

    Return (minimum, maximum), or None if the documentation doesn't start
    with a prototype.
    """
    prototype = parsePrototype(doc)
    if not prototype:
        return None

    mandatory, star_argument, optional, _defaults = prototype
    lowest = len(mandatory)
    highest = lowest + len(optional)
    if star_argument:
        highest = highest + max_var_arg
    return lowest, highest

# Run the fuzzer when this file is executed as a script
if __name__ == "__main__":
    Fuzzer().main()

