'''Utility classes and functions for getmail.
'''

# Public names exported when this module is star-imported.
# NOTE(review): format_params, alarm_handler and format_header are defined
# in this file but are not listed here -- confirm whether that is intentional.
__all__ = [
    'address_no_brackets',
    'change_usergroup',
    'change_uidgid',
    'check_ssl_key_and_cert',
    'deliver_maildir',
    'eval_bool',
    'expand_user_vars',
    'is_maildir',
    'localhostname',
    'lock_file',
    'logfile',
    'mbox_from_escape',
    'safe_open',
    'unlock_file',
    'gid_of_uid',
    'uid_of_user',
    'updatefile',
]


import os
import os.path
import socket
import signal
import stat
import time
import glob

import fcntl
import pwd
import grp

from getmailcore.exceptions import *

# time.strftime() format used for the timestamp that logfile.write() puts at
# the start of every log line.
logtimeformat = '%Y-%m-%d %H:%M:%S'
# Spellings accepted by eval_bool(), which looks up the lowercased string form
# of its argument here; anything not listed is rejected.
_bool_values = {
    'true'  : True,
    'yes'   : True,
    'on'    : True,
    '1'     : True,
    'false' : False,
    'no'    : False,
    'off'   : False,
    '0'     : False
}

#######################################
def lock_file(file, locktype):
    '''Take an exclusive lock on an open file.

    file is handed unchanged to the fcntl module.  locktype names the
    locking call to use and must be either 'lockf' or 'flock'; the lock is
    requested with fcntl.LOCK_EX only (no non-blocking flag).
    '''
    assert locktype in ('lockf', 'flock'), 'unknown lock type %s' % locktype
    lockers = (
        ('lockf', fcntl.lockf),
        ('flock', fcntl.flock),
    )
    for name, acquire in lockers:
        if locktype == name:
            acquire(file, fcntl.LOCK_EX)
            return

#######################################
def unlock_file(file, locktype):
    '''Release a lock on an open file.

    file is handed unchanged to the fcntl module.  locktype names the
    locking call to use and must be either 'lockf' or 'flock'; the call is
    made with fcntl.LOCK_UN.
    '''
    assert locktype in ('lockf', 'flock'), 'unknown lock type %s' % locktype
    unlockers = (
        ('lockf', fcntl.lockf),
        ('flock', fcntl.flock),
    )
    for name, release in unlockers:
        if locktype == name:
            release(file, fcntl.LOCK_UN)
            return

#######################################
def safe_open(path, mode, permissions=0600):
    '''Open a file path safely.
    '''
    if os.name != 'posix':
        return open(path, mode)
    try:
        fd = os.open(path, os.O_RDWR | os.O_CREAT | os.O_EXCL, permissions)
        file = os.fdopen(fd, mode)
    except OSError, o:
        raise getmailDeliveryError('failure opening %s (%s)' % (path, o))
    return file

#######################################
class updatefile(object):
    '''A class for atomically updating files.

    A new, temporary file is created when this class is instantiated. When the
    object's close() method is called, the file is synced to disk and atomically
    renamed to replace the original file.  close() is automatically called when
    the object is deleted.
    '''
    def __init__(self, filename):
        self.closed = False
        self.filename = filename
        self.tmpname = filename + '.tmp.%d' % os.getpid()
	# If the target is a symlink, the rename-on-close semantics of this
	# class would break the symlink, replacing it with the new file.
	# Instead, follow the symlink here, and replace the target file on
	# close.
	while os.path.islink(filename):
	    filename = os.path.join(os.path.dirname(filename),
	                            os.readlink(filename))
        try:
            f = safe_open(self.tmpname, 'wb')
        except IOError, (code, msg):
            raise IOError('%s, opening output file "%s"' % (msg, self.tmpname))
        self.file = f
        self.write = f.write
        self.flush = f.flush

    def __del__(self):
        self.close()

    def abort(self):
        try:
            if hasattr(self, 'file'):
                self.file.close()
        except IOError:
            pass
        self.closed = True

    def close(self):
        if self.closed or not hasattr(self, 'file'):
            return
        self.file.flush()
        self.file.close()
        os.rename(self.tmpname, self.filename)
        self.closed = True

#######################################
class logfile(object):
    '''A class for locking and appending timestamped data lines to a log file.
    '''
    def __init__(self, filename):
        self.closed = False
        self.filename = filename
        try:
            self.file = open(expand_user_vars(self.filename), 'ab')
        except IOError, (code, msg):
            raise IOError('%s, opening file "%s"' % (msg, self.filename))

    def __del__(self):
        self.close()

    def __str__(self):
        return 'logfile(filename="%s")' % self.filename

    def close(self):
        if self.closed:
            return
        self.file.flush()
        self.file.close()
        self.closed = True

    def write(self, s):
        try:
            lock_file(self.file, 'flock')
            # Seek to end
            self.file.seek(0, 2)
            self.file.write(time.strftime(logtimeformat, time.localtime())
                + ' ' + s.rstrip() + os.linesep)
            self.file.flush()
        finally:
            unlock_file(self.file, 'flock')

#######################################
def format_params(d, maskitems=('password', ), skipitems=()):
    '''Take a dictionary of parameters and return a string summary.
    '''
    s = ''
    keys = d.keys()
    keys.sort()
    for key in keys:
        if key in skipitems:
            continue
        if s:
            s += ','
        if key in maskitems:
            s += '%s=*' % key
        else:
            s += '%s="%s"' % (key, d[key])
    return s

###################################
def alarm_handler(*unused):
    '''Signal handler installed by deliver_maildir() for SIGALRM.

    Ignores whatever arguments it is called with and turns the alarm into a
    getmailDeliveryError.  Not expected to fire in normal operation.
    '''
    reason = 'Delivery timeout'
    raise getmailDeliveryError(reason)

#######################################
def is_maildir(d):
    '''Verify a path is a maildir.

    Returns True if d is a directory containing tmp, cur and new
    subdirectories, and False if d or any of the three subdirectories is
    not a directory.

    Raises getmailConfigurationError if the parent directory of d or d
    itself is not searchable, or if one of the subdirectories is not
    writable, since the answer cannot be trusted in those cases.
    '''
    # Work from the absolute path: a bare relative name such as 'Maildir'
    # has an empty dirname, and os.access('') always fails.  abspath() also
    # discards any trailing or repeated slashes.
    dir_parent = os.path.dirname(os.path.abspath(d))
    if not os.access(dir_parent, os.X_OK):
        raise getmailConfigurationError(
            'cannot read contents of parent directory of %s '
            '- check permissions and ownership' % d
        )
    if not os.path.isdir(d):
        return False
    if not os.access(d, os.X_OK):
        raise getmailConfigurationError(
            'cannot read contents of directory %s '
            '- check permissions and ownership' % d
        )
    for sub in ('tmp', 'cur', 'new'):
        subdir = os.path.join(d, sub)
        if not os.path.isdir(subdir):
            return False
        if not os.access(subdir, os.W_OK):
            raise getmailConfigurationError(
                'cannot write to maildir %s '
                '- check permissions and ownership' % d
            )
    return True

#######################################
def deliver_maildir(maildirpath, data, hostname, dcount=None, filemode=0600):
    '''Reliably deliver a mail message into a Maildir.  Uses Dan Bernstein's
    documented rules for maildir delivery, and the updated naming convention
    for new files (modern delivery identifiers).  See
    http://cr.yp.to/proto/maildir.html and
    http://qmail.org/man/man5/maildir.html for details.
    '''
    if not is_maildir(maildirpath):
        raise getmailDeliveryError('not a Maildir (%s)' % maildirpath)

    # Set a 24-hour alarm for this delivery
    signal.signal(signal.SIGALRM, alarm_handler)
    signal.alarm(24 * 60 * 60)

    info = {
        'deliverycount' : dcount,
        'hostname' : hostname.split('.')[0].replace('/', '\\057').replace(
            ':', '\\072'),
        'pid' : os.getpid(),
    }
    dir_tmp = os.path.join(maildirpath, 'tmp')
    dir_new = os.path.join(maildirpath, 'new')

    for unused in range(3):
        t = time.time()
        info['secs'] = int(t)
        info['usecs'] = int((t - int(t)) * 1000000)
        info['unique'] = 'M%(usecs)dP%(pid)s' % info
        if info['deliverycount'] is not None:
            info['unique'] += 'Q%(deliverycount)s' % info
        try:
            info['unique'] += 'R%s' % ''.join(['%02x' % ord(char)
                for char in open('/dev/urandom', 'rb').read(8)])
        except StandardError:
            pass

        filename = '%(secs)s.%(unique)s.%(hostname)s' % info
        fname_tmp = os.path.join(dir_tmp, filename)
        fname_new = os.path.join(dir_new, filename)

        # File must not already exist
        if os.path.exists(fname_tmp):
            # djb says sleep two seconds and try again
            time.sleep(2)
            continue

        # Be generous and check cur/file[:...] just in case some other, dumber
        # MDA is in use.  We wouldn't want them to clobber us and have the user
        # blame us for their bugs.
        curpat = os.path.join(maildirpath, 'cur', filename) + ':*'
        collision = glob.glob(curpat)
        if collision:
            # There is a message in maildir/cur/ which could be clobbered by
            # a dumb MUA, and which shouldn't be there.  Abort.
            raise getmailDeliveryError('collision with %s' % collision)

        # Found an unused filename
        break
    else:
        signal.alarm(0)
        raise getmailDeliveryError('failed to allocate file in maildir')

    # Get user & group of maildir
    s_maildir = os.stat(maildirpath)

    # Open file to write
    try:
        f = safe_open(fname_tmp, 'wb', filemode)
        f.write(data)
        f.flush()
        os.fsync(f.fileno())
        f.close()

    except IOError, o:
        signal.alarm(0)
        raise getmailDeliveryError('failure writing file %s (%s)'
            % (fname_tmp, o))

    # Move message file from Maildir/tmp to Maildir/new
    try:
        os.link(fname_tmp, fname_new)
        os.unlink(fname_tmp)

    except OSError:
        signal.alarm(0)
        try:
            os.unlink(fname_tmp)
        except KeyboardInterrupt:
            raise
        except StandardError:
            pass
        raise getmailDeliveryError('failure renaming "%s" to "%s"'
            % (fname_tmp, fname_new))

    # Delivery done

    # Cancel alarm
    signal.alarm(0)
    signal.signal(signal.SIGALRM, signal.SIG_DFL)

    return filename

#######################################
def mbox_from_escape(s):
    '''Return s with every space, tab and newline replaced by '-'.

    Intended for the envelope sender address; if the result is empty,
    '<>' is returned instead.
    '''
    escaped = []
    for char in s:
        if char in (' ', '\t', '\n'):
            escaped.append('-')
        else:
            escaped.append(char)
    return ''.join(escaped) or '<>'

#######################################
def address_no_brackets(addr):
    '''Return addr without its enclosing angle brackets.

    The brackets are removed only when addr both starts with '<' and ends
    with '>'; any other value is returned unchanged.
    '''
    bracketed = addr.startswith('<') and addr.endswith('>')
    if not bracketed:
        return addr
    return addr[1:-1]

#######################################
def eval_bool(s):
    '''Convert a configuration value to a boolean.

    The string form of s is lowercased and looked up in _bool_values.
    Raises getmailConfigurationError if it is not one of the recognized
    spellings.
    '''
    key = str(s).lower()
    if key not in _bool_values:
        raise getmailConfigurationError('boolean parameter requires value'
            ' to be one of true or false, not "%s"' % s)
    return _bool_values[key]

#######################################
def gid_of_uid(uid):
    try:
        return pwd.getpwuid(uid).pw_gid
    except KeyError, o:
        raise getmailConfigurationError('no such specified uid (%s)' % o)

#######################################
def uid_of_user(user):
    try:
        return pwd.getpwnam(user).pw_uid
    except KeyError, o:
        raise getmailConfigurationError('no such specified user (%s)' % o)

#######################################
def change_usergroup(logger=None, user=None, _group=None):
    '''
    Change the current effective GID and UID to those specified by user and
    _group.
    '''
    uid = None
    gid = None
    if _group:
        if logger:
            logger.debug('Getting GID for specified group %s\n' % _group)
        try:
            gid = grp.getgrnam(_group).gr_gid
        except KeyError, o:
            raise getmailConfigurationError('no such specified group (%s)' % o)
    if user:
        if logger:
            logger.debug('Getting UID for specified user %s\n' % user)
        uid = uid_of_user(user)

    change_uidgid(logger, uid, gid)

#######################################
def change_uidgid(logger=None, uid=None, gid=None):
    '''
    Change the current effective GID and UID to those specified by uid
    and gid.
    '''
    try:
        if gid:
            if os.getegid() != gid:
                if logger:
                    logger.debug('Setting egid to %d\n' % gid)
                os.setregid(gid, gid)
        if uid:
            if os.geteuid() != uid:
                if logger:
                    logger.debug('Setting euid to %d\n' % uid)
                os.setreuid(uid, uid)
    except OSError, o:
        raise getmailDeliveryError('change UID/GID to %s/%s failed (%s)'
            % (uid, gid, o))

#######################################
def format_header(name, line):
    '''Take a long line and return rfc822-style multiline header.

    name and the lines of line are stripped and joined into a single
    "name: value" string, which is then folded at spaces.  Continuation
    lines are indented with two spaces, lines are separated with
    os.linesep, and the result always ends with os.linesep.
    '''
    line = (name.strip() + ': '
        + ' '.join([part.strip() for part in line.splitlines()]))
    folded = []
    # Split into lines of maximum 78 characters long plus newline, if
    # possible.  A long line may result if no space characters are present.
    while len(line) > 78:
        i = line.rfind(' ', 0, 78)
        if i == -1:
            # No space in first 78 characters; break at the first space
            # after that, so only the unbreakable run ends up overlong.
            i = line.find(' ')
            if i == -1:
                # No space at all
                break
        folded.append(line[:i])
        line = line[i:].lstrip()
    # The remainder can be empty (e.g. the value ended in whitespace); in
    # that case no empty continuation line is emitted.
    if line:
        folded.append(line.strip())
    return (os.linesep + '  ').join(folded) + os.linesep

#######################################
def expand_user_vars(s):
    '''Expand environment variables ("$varname" or "${varname}") in s, then
    expand a leading "~/" or "~username/" in the result, and return it.
    '''
    with_vars = os.path.expandvars(s)
    return os.path.expanduser(with_vars)

#######################################
def localhostname():
    '''Return the local host name, preferring a fully-qualified form.

    socket.gethostname() is used as-is when it already contains a dot;
    otherwise the result of socket.getfqdn() is returned.
    '''
    hostname = socket.gethostname()
    if '.' not in hostname:
        hostname = socket.getfqdn()
    return hostname

#######################################
def check_ssl_key_and_cert(conf):
    '''Validate the 'keyfile' and 'certfile' entries of conf.

    Each value that is not None is passed through expand_user_vars().
    Returns the (keyfile, certfile) pair after expansion.

    Raises getmailConfigurationError if a non-empty path does not name an
    existing file, or if exactly one of the two values is None.
    '''
    def expanded(option):
        value = conf[option]
        if value is None:
            return None
        return expand_user_vars(value)

    keyfile = expanded('keyfile')
    certfile = expanded('certfile')

    if keyfile and not os.path.isfile(keyfile):
        raise getmailConfigurationError('optional keyfile must be'
            ' path to a valid file')
    if certfile and not os.path.isfile(certfile):
        raise getmailConfigurationError('optional certfile must be'
            ' path to a valid file')

    have_key = keyfile is not None
    have_cert = certfile is not None
    if have_key != have_cert:
        raise getmailConfigurationError('optional certfile and keyfile'
            ' must be supplied together')
    return keyfile, certfile
