#!/usr/bin/python3
# -*- python -*-

# Copyright 2014..2018, W. Martin Borgert <debacle@debian.org>
# License: AGPL-3+

# Python standard modules
import argparse
import collections
import configparser
import email.mime.text
import email.utils
import hashlib
import html
import os
import smtplib
import socket
import subprocess
import sys
import textwrap

# additional modules
import apt
import prettytable
import sleekxmpp

# Human-readable program name; used as the XMPP nick and the X-Mailer header.
longname = "Pain in the APT"
# Short program name; used to build the default config and stamp file paths.
shortname = "painintheapt"
# Version string reported by the --version command line option.
version = "0.20180212"

# Header row of the update table.  The lower-cased column names double as the
# field names of the Package record (name, installed, candidate).
columns = ["Name", "Installed", "Candidate"]
Package = collections.namedtuple('Package', " ".join(columns).lower())


def getargs():
    """Parse the command line.

    Return the argparse namespace with the attributes ``configfile``,
    ``debug``, ``force`` and ``stampfile``.  ``--version`` prints the
    program version and exits.
    """
    parser = argparse.ArgumentParser(
        description=('Pester people about available package updates'
                     ' by email or jabber.'),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # file locations, both derived from the short program name
    default_config = '/etc/%s.conf' % shortname
    default_stamp = '/var/lib/%s/stamp' % shortname

    parser.add_argument(
        '-c', '--configfile', default=default_config,
        help='configuration file')
    parser.add_argument(
        '-d', '--debug', default=False, action='store_true',
        help='print debug output to stderr')
    parser.add_argument(
        '-f', '--force', default=False, action='store_true',
        help='send message, even if updates did not change')
    parser.add_argument(
        '-s', '--stampfile', help='stamp file', default=default_stamp)
    parser.add_argument(
        '-v', '--version', action='version',
        version='%(prog)s ' + version)

    return parser.parse_args()


def update():
    """Refresh the APT cache and collect the pending package changes.

    The cache is updated, reopened and a dist-upgrade is marked on it.
    Every change reported by the cache is turned into a Package record;
    a missing installed or candidate version is shown as "-".

    Return a tuple of the cache and the list of Package records.
    """
    def version_or_dash(ver):
        # a package that is not installed (or has no candidate) has no
        # version object to ask
        return ver.version if ver else "-"

    cache = apt.Cache()
    cache.update()
    cache.open()
    cache.upgrade(dist_upgrade=True)

    updates = []
    for change in cache.get_changes():
        pkgname = change._pkg.name
        entry = cache[pkgname]
        updates.append(Package(pkgname,
                               version_or_dash(entry.installed),
                               version_or_dash(entry.candidate)))
    return cache, updates


def wrap(text, maxwid):
    """Return text re-flowed to lines of at most maxwid characters.

    The wrapped lines are joined by newlines into a single string.
    """
    return textwrap.fill(text, maxwid)


# Formatted changelog text from an earlier get_changelogs() call.
_changes = None


def get_changelogs(cache, send_changes):
    """Download and format the changelogs of all pending changes.

    Beware: This is very slow.  A non-empty result is therefore kept in
    the module-level _changes and handed out again on later calls.

    Binary packages with an identical changelog are listed together
    above a single copy of that changelog.  Returns "" when send_changes
    is false.
    """
    global _changes
    if not send_changes:
        return ""
    if _changes:
        return _changes

    # changelog text -> names of the packages sharing that text
    names_by_text = collections.defaultdict(list)
    for change in cache.get_changes():
        pkgname = change._pkg.name
        text = cache[pkgname].get_changelog().strip()
        names_by_text[text].append(pkgname)

    # now do some very fancy formatting
    maxwid = 79
    separator = "\n" + "-" * maxwid + "\n"
    entries = []
    for text, names in names_by_text.items():
        heading = wrap(', '.join(sorted(names)), maxwid)
        entries.append(heading + ":\n\n" + text)
    entries.sort()

    _changes = separator.join(entries)
    return _changes


def maketable(lst):
    """Render the package updates as a text table.

    lst holds records with the attributes name, installed and candidate.
    Rows are sorted by the first column, cells are left-aligned and each
    cell is wrapped to 23 characters.  Returns the table as a string.
    """
    maxwid = 23
    table = prettytable.PrettyTable(columns)
    table.sortby = columns[0]
    table.align = 'l'
    for pkg in lst:
        cells = (pkg.name, pkg.installed, pkg.candidate)
        table.add_row([wrap(cell, maxwid) for cell in cells])
    return table.get_string()


class JabberBot(sleekxmpp.ClientXMPP):
    """XMPP client that delivers one message and then disconnects.

    Once the session has started, the message is sent to every JID in
    ``to``, to the multi-user chat ``room`` (if set) and published as an
    Atom entry on the pubsub node (if both service and node are set).
    """

    def __init__(self, jid, password, to, room, pubsub_service, pubsub_node,
                 nick, subject, message):
        super().__init__(jid, password)
        # what to send
        self.subject = subject
        self.message = message
        # where to send it
        self.to = to
        self.room = room
        self.nick = nick
        self.pubsub_service = pubsub_service
        self.pubsub_node = pubsub_node
        self.add_event_handler("session_start", self.start)

    def start(self, event):
        """Handle "session_start": deliver the message, then disconnect."""
        self.getRoster()
        self.sendPresence()

        # direct messages
        for recipient in self.to:
            self.send_message(
                mto=recipient, msubject=self.subject, mbody=self.message)

        # multi-user chat
        if self.room:
            self.plugin['xep_0045'].joinMUC(self.room, self.nick, wait=True)
            self.send_message(mto=self.room, mbody=self.message,
                              mtype='groupchat')

        # pubsub: subject and message become an escaped Atom entry
        if self.pubsub_service and self.pubsub_node:
            payload = "".join([
                '<entry xmlns="http://www.w3.org/2005/Atom"><title>',
                html.escape(self.subject),
                '</title><content type="xhtml">',
                '<pre xmlns="http://www.w3.org/1999/xhtml">',
                html.escape(self.message),
                '</pre></content></entry>',
            ])
            entry = sleekxmpp.xmlstream.ET.fromstring(payload)
            self['xep_0060'].publish(
                self.pubsub_service, self.pubsub_node, payload=entry)

        try:
            self.disconnect(wait=True)
        except TypeError:      # older SleekXMPP doesn't have "wait"
            import time
            time.sleep(10)
            self.disconnect()


def read_password(config, config_dir):
    """Return the password configured for one section.

    config is the section mapping and config_dir the directory of the
    configuration file.  If ``password_file`` is set, it is resolved
    relative to config_dir and its stripped content is returned.
    Otherwise the deprecated ``password`` option is returned ("" if it
    is unset as well).

    Raises OSError if the password file cannot be read.
    """
    password_file = config.get("password_file", "").strip()
    if password_file:
        with open(os.path.join(config_dir, password_file)) as f:
            return f.read().strip()

    password = config.get("password", "")
    if password:
        # Warn only when the deprecated option is really in use, so that
        # sections without any password do not produce noise on stderr.
        print("password deprecated, use password_file instead",
              file=sys.stderr)
    return password


def sendxmpp(config, config_dir, table, count, host, debug, changes):
    """Send the update report by XMPP.

    config is the XMPP section of the configuration.  It provides the
    ``jid``, the password (see read_password), a comma-separated list of
    recipient JIDs in ``to``, an optional multi-user chat ``room`` and an
    optional ``pubsub_service``/``pubsub_node`` pair.  table and changes
    form the message body, count and host go into the subject.  debug is
    accepted for a uniform sender signature but not used here.

    Raises ConnectionError if the connection cannot be established.
    """
    jid = config.get("jid", "")
    password = read_password(config, config_dir)
    # "".split(",") yields [""]; drop blank entries so that no message is
    # addressed to an empty JID, and tolerate spaces around the commas.
    to = [addr.strip()
          for addr in config.get("to", "").split(",") if addr.strip()]
    room = config.get("room")
    pubsub_service = config.get("pubsub_service", "").strip()
    pubsub_node = config.get("pubsub_node", "").strip()
    subject = '%d package update(s) for %s' % (count, host)
    xmpp = JabberBot(
        jid, password, to, room, pubsub_service, pubsub_node,
        longname, subject, "\n\n".join([table, changes]).strip())
    xmpp.register_plugin('xep_0030')        # service discovery
    if room:
        xmpp.register_plugin('xep_0045')    # multi-user chat
    if pubsub_service and pubsub_node:
        xmpp.register_plugin('xep_0060')    # pubsub
    xmpp.register_plugin('xep_0199')        # XMPP ping

    if not xmpp.connect():
        # Raising a plain string is itself a TypeError in Python 3 and
        # would hide this message; raise a real exception instead.
        raise ConnectionError("XMPP connect() failed")
    xmpp.process(threaded=False)


def sendsmtp(config, config_dir, table, count, host, debug, changes):
    """Send email by SMTP to whomsoever it may concern.

    config is the SMTP section of the configuration.  It provides
    ``server`` (default "localhost"), ``port`` (default 25), ``username``,
    the password (see read_password), ``from`` and ``to`` (both default
    to the username) and an optional ``cc``.  table and changes form the
    mail body, count and host go into the subject.  With debug set, the
    SMTP dialogue is printed by smtplib.

    The connection always uses STARTTLS; login happens only if a
    username or password is configured.  Errors from smtplib propagate
    to the caller.
    """
    server = config.get("server", "localhost")
    port = config.getint("port", 25)
    username = config.get("username", "")
    password = read_password(config, config_dir)
    from_ = config.get("from", username)
    to = config.get("to", username)
    cc = config.get("cc", "")

    msg = email.mime.text.MIMEText(
        "\n\n".join([table, changes]).strip(), 'plain', 'utf-8')
    msg['From'] = from_
    msg['To'] = to
    msg['Subject'] = '%d package update(s) for %s' % (count, host)
    msg['X-Mailer'] = longname

    if cc:
        msg['Cc'] = cc

    # Envelope recipients: every distinct, non-empty address of To and Cc.
    recipients = sorted({
        addr for _, addr in email.utils.getaddresses([to + "," + cc])
        if addr})

    # The context manager sends QUIT and closes the socket even if one of
    # the steps below raises, so a failed run does not leak a connection.
    with smtplib.SMTP(host=server, port=port) as s:
        if debug:
            s.set_debuglevel(True)
        s.starttls()
        s.ehlo_or_helo_if_needed()
        if username or password:
            s.login(username, password)
        s.sendmail(from_, recipients, msg.as_string())


def sendmailx(config, config_dir, table, count, host, debug, changes):
    """Send email by mailx to whomsoever it may concern.

    config is the MAILX section of the configuration.  It provides
    ``from`` and ``to`` (both default to "root") and an optional ``cc``.
    table and changes form the mail body, count and host go into the
    subject.  config_dir and debug are accepted for a uniform sender
    signature but not used here.

    The body is piped to /usr/bin/mailx as UTF-8.
    """
    cmd = ["/usr/bin/mailx",
           "-r", config.get("from", "root"),
           "-s", '%d package update(s) for %s' % (count, host),
           "-a", "X-Mailer: " + longname]
    cc = config.get("cc", "")
    if cc:
        cmd += ["-c", cc]
    # this is taken from apticron
    if os.path.realpath("/usr/bin/mailx") == "/usr/bin/heirloom-mailx":
        cmd += ["-S", "ttycharset=utf-8"]
    else:
        cmd += ["-a", "MIME-Version: 1.0",
                "-a", "Content-type: text/plain; charset=UTF-8",
                "-a", "Content-transfer-encoding: 8bit"]
    to = config.get("to", "root")
    body = "\n\n".join([table, changes]).strip()
    mailx = subprocess.Popen(cmd + [to], stdin=subprocess.PIPE)
    # The pipe is a binary stream, so writing str raises TypeError; encode
    # to match the UTF-8 charset announced above.  communicate() writes the
    # data, closes stdin and waits for mailx to finish.
    mailx.communicate(body.encode("utf-8"))


def has_changed(configfile, table, stampfile):
    """Tell whether configuration or update table differ from the last run.

    A SHA-1 digest is computed over the content of configfile followed by
    table and compared with the first line of stampfile.  A missing or
    unreadable stamp file counts as a change.

    Return a tuple (change, newhash): change is True if the digests
    differ, newhash is the hex digest of the current state.

    Raises OSError if configfile cannot be read.
    """
    hashsum = hashlib.sha1()
    # Close the configuration file deterministically instead of leaving
    # the handle to the garbage collector.
    with open(configfile) as f:
        for line in f:
            hashsum.update(line.encode("utf-8"))
    hashsum.update(table.encode("utf-8"))
    newhash = hashsum.hexdigest()
    try:
        with open(stampfile) as f:
            oldhash = f.readline().strip()
    except (OSError, ValueError):
        # no stamp yet, or not readable/decodable: treat as changed
        oldhash = "invalid"
    return oldhash != newhash, newhash


class AcquireProgress(apt.progress.text.AcquireProgress):
    """Text acquire progress that is only visible in debug mode.

    With debug set, progress output goes to stderr; otherwise it is
    written to /dev/null.
    """

    def __init__(self, debug):
        if debug:
            sink = sys.stderr
        else:
            sink = open("/dev/null", "w")
        super().__init__(outfile=sink)


if __name__ == '__main__':
    # command line and configuration
    args = getargs()
    config = configparser.ConfigParser()
    config.read(args.configfile)
    config_dir = os.path.dirname(args.configfile)

    # pending updates and their table representation
    cache, updates = update()
    count = len(updates)
    if count:
        table = maketable(updates)
    else:
        table = ""

    # host name for the subject line
    fqdn = socket.getfqdn()
    # workaround for dodgy /etc/hosts
    if fqdn in ('localhost', 'localhost.localdomain'):
        fqdn = socket.gethostname() or fqdn

    change, newhash = has_changed(args.configfile, table, args.stampfile)
    # send only if something changed since the last run, or when forced
    notify = change or args.force

    ret = 0
    senders = (("XMPP", sendxmpp),
               ("SMTP", sendsmtp),
               ("MAILX", sendmailx))
    for section, function in senders:
        # a failing transport must not keep the others from sending
        try:
            if notify and section in config.sections():
                options = config[section]
                send_changes = options.getboolean("send_changes", True)
                changelogs = get_changelogs(cache, send_changes)
                function(options, config_dir, table, count, fqdn,
                         args.debug, changelogs)
        except Exception as err:
            print(str(err), file=sys.stderr)
            ret = 1

    # remember the state that has just been reported
    if notify:
        with open(args.stampfile, "wb") as f:
            f.write(newhash.encode("utf-8"))

    # download the packages of the marked changes
    cache.fetch_archives(progress=AcquireProgress(args.debug))

    sys.exit(ret)
