# Copyright (C) 2012, Dhiru Kholia <dhiru@openwall.com>
#
# Modified for JtR
#
# Copyright (C) 2011  Jeff Forcier <jeff@bitprophet.org>
#
# This file is part of ssh.
#
# 'ssh' is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# 'ssh' is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with 'ssh'; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA.

import traceback
import binascii
import base64
import StringIO
import sys
from hashlib import md5 as MD5

# Set to True when PyCrypto cannot be imported.  Code further down checks this
# flag and skips the decryption step when it is set.
limited = False

try:
    from Crypto.Cipher import DES3, AES
except ImportError:
    print >> sys.stderr, "PyCrypto is missing in your Python installation, %s is operating in limited features mode!" % sys.argv[0]
    limited = True


class BERException (Exception):
    """
    Raised by L{BER} when it meets an encoding type it does not know how to
    decode.
    """


class BER(object):
    """
    A very small BER decoder.  Only SEQUENCE (0x30) and INTEGER (2) values
    are understood; anything else raises L{BERException}.
    """

    def __init__(self, content=''):
        # raw encoded bytes and the read cursor into them
        self.content = content
        self.idx = 0

    def __str__(self):
        return self.content

    def __repr__(self):
        return "BER('%s')" % repr(self.content)

    def decode(self):
        """
        Decode and return the next value (see L{decode_next}).
        """
        return self.decode_next()

    def decode_next(self):
        """
        Decode one value starting at the current cursor and advance past it.

        @return: a list for a sequence, a long for an integer, or C{None}
            when the remaining bytes are exhausted or truncated.
        @raise BERException: if the identifier is neither sequence nor int.
        """
        buf = self.content
        total = len(buf)
        if self.idx >= total:
            return None
        ident = ord(buf[self.idx])
        self.idx += 1
        if (ident & 31) == 31:
            # high-tag-number form: identifier is spread over following
            # octets, 7 bits each, continuation flagged by the top bit
            ident = 0
            while self.idx < total:
                octet = ord(buf[self.idx])
                self.idx += 1
                ident = (ident << 7) | (octet & 0x7f)
                if not (octet & 0x80):
                    break
        if self.idx >= total:
            return None
        # length octet(s)
        size = ord(buf[self.idx])
        self.idx += 1
        if size & 0x80:
            # long form: low 7 bits give the number of length octets
            # FIXME: theoretically should handle indefinite-length (0x80)
            nbytes = size & 0x7f
            if self.idx + nbytes > total:
                return None
            size = inflate_long(buf[self.idx : self.idx + nbytes], True)
            self.idx += nbytes
        if self.idx + size > total:
            # declared payload does not fit in what is left
            return None
        payload = buf[self.idx : self.idx + size]
        self.idx += size
        # dispatch on the identifier
        if ident == 2:
            return inflate_long(payload)
        if ident == 0x30:
            return self.decode_sequence(payload)
        # 1: boolean (00 false, otherwise true)
        raise BERException('Unknown ber encoding type %d (robey is lazy)' % ident)

    @staticmethod
    def decode_sequence(data):
        """
        Decode every value contained in C{data} and return them as a list.
        """
        decoder = BER(data)
        items = []
        item = decoder.decode_next()
        while item is not None:
            items.append(item)
            item = decoder.decode_next()
        return items

"""
Exceptions defined by ssh.
"""


class SSHException (Exception):
    """
    Base exception for SSH2 protocol negotiation failures and logic errors.
    """


class AuthenticationException (SSHException):
    """
    Raised when authentication fails for some reason; retrying with other
    credentials may succeed.  (Subclasses give more specific reasons.)

    @since: 1.6
    """


class PasswordRequiredException (AuthenticationException):
    """
    Raised when a private key file cannot be unlocked without a password.
    """


class BadAuthenticationType (AuthenticationException):
    """
    Raised when an authentication type (such as password) is attempted but
    the server does not allow it.  (It may only allow public-key, for
    example.)

    @ivar allowed_types: list of allowed authentication types provided by the
        server (possible values are: C{"none"}, C{"password"}, and
        C{"publickey"}).
    @type allowed_types: list

    @since: 1.1
    """
    allowed_types = []

    def __init__(self, explanation, types):
        super(BadAuthenticationType, self).__init__(explanation)
        self.allowed_types = types

    def __str__(self):
        base_text = SSHException.__str__(self)
        suffix = ' (allowed_types=%r)' % self.allowed_types
        return base_text + suffix


class PartialAuthentication (AuthenticationException):
    """
    Internal exception used to signal partial authentication.
    """
    allowed_types = []

    def __init__(self, types):
        super(PartialAuthentication, self).__init__('partial authentication')
        self.allowed_types = types


class ChannelException (SSHException):
    """
    Raised when opening a new L{Channel} fails.

    @ivar code: the error code returned by the server
    @type code: int

    @since: 1.6
    """
    def __init__(self, code, text):
        super(ChannelException, self).__init__(text)
        self.code = code


class BadHostKeyException (SSHException):
    """
    Raised when the host key offered by the SSH server differs from the one
    that was expected.

    @ivar hostname: the hostname of the SSH server
    @type hostname: str
    @ivar key: the host key presented by the server
    @type key: L{PKey}
    @ivar expected_key: the host key expected
    @type expected_key: L{PKey}

    @since: 1.6
    """
    def __init__(self, hostname, got_key, expected_key):
        message = 'Host key for server %s does not match!' % hostname
        super(BadHostKeyException, self).__init__(message)
        self.hostname = hostname
        self.key = got_key
        self.expected_key = expected_key


"""
Useful functions used by the rest of ssh.
"""


from binascii import hexlify, unhexlify
import sys
import struct


# Change by RogerB - python < 2.3 has no built-in enumerate, so provide a
# minimal stand-in there; newer interpreters keep the built-in.
if sys.version_info < (2, 3):
    class enumerate:
        def __init__(self, sequence):
            self.sequence = sequence

        def __iter__(self):
            # yield (position, element) pairs, counting from zero
            index = 0
            for element in self.sequence:
                yield (index, element)
                index += 1


def inflate_long(s, always_positive=False):
    "turns a normalized byte string into a long-int (adapted from Crypto.Util.number)"
    out = 0L
    negative = 0
    if not always_positive and (len(s) > 0) and (ord(s[0]) >= 0x80):
        negative = 1
    if len(s) % 4:
        filler = '\x00'
        if negative:
            filler = '\xff'
        s = filler * (4 - len(s) % 4) + s
    for i in range(0, len(s), 4):
        out = (out << 32) + struct.unpack('>I', s[i:i+4])[0]
    if negative:
        out -= (1L << (8 * len(s)))
    return out

def deflate_long(n, add_sign_padding=True):
    "turns a long-int into a normalized byte string (adapted from Crypto.Util.number)"
    # after much testing, this algorithm was deemed to be the fastest
    s = ''
    n = long(n)
    while (n != 0) and (n != -1):
        s = struct.pack('>I', n & 0xffffffffL) + s
        n = n >> 32
    # strip off leading zeros, FFs
    for i in enumerate(s):
        if (n == 0) and (i[1] != '\000'):
            break
        if (n == -1) and (i[1] != '\xff'):
            break
    else:
        # degenerate case, n was either 0 or -1
        i = (0,)
        if n == 0:
            s = '\000'
        else:
            s = '\xff'
    s = s[i[0]:]
    if add_sign_padding:
        if (n == 0) and (ord(s[0]) >= 0x80):
            s = '\x00' + s
        if (n == -1) and (ord(s[0]) < 0x80):
            s = '\xff' + s
    return s

def format_binary_weird(data):
    """
    Hex-dump C{data} as upper-case byte pairs: a space after every second
    byte and a newline after every sixteenth.
    """
    pieces = []
    for index, char in enumerate(data):
        pieces.append('%02X' % ord(char))
        if index % 2:
            pieces.append(' ')
        if index % 16 == 15:
            pieces.append('\n')
    return ''.join(pieces)

def format_binary(data, prefix=''):
    """
    Split C{data} into rows of at most 16 bytes, format each row with
    L{format_binary_line}, and return the rows with C{prefix} prepended.
    """
    return [prefix + format_binary_line(data[offset:offset + 16])
            for offset in range(0, len(data), 16)]

def format_binary_line(data):
    """
    Format one row of a hex dump: space-separated hex bytes padded to 50
    columns, followed by the bytes as text with non-printables shown as '.'.
    """
    hex_cells = []
    text_cells = []
    for ch in data:
        code = ord(ch)
        hex_cells.append('%02X' % code)
        # index 1 of the 4-char template is the character itself; it is
        # selected only for codes 32..126, everything else maps to a '.'
        text_cells.append(('.%c..' % ch)[(code + 63) // 95])
    return '%-50s %s' % (' '.join(hex_cells), ''.join(text_cells))

def hexify(s):
    "hex-encode a byte string, using upper-case digits"
    encoded = hexlify(s)
    return encoded.upper()

def unhexify(s):
    "decode a hex string back into the bytes it represents"
    decoded = unhexlify(s)
    return decoded

def safe_string(s):
    """
    Return C{s} with every character outside the range 32..127 replaced by
    a C{%XX} escape (upper-case hex).
    """
    escaped = []
    for c in s:
        code = ord(c)
        if 32 <= code <= 127:
            escaped.append(c)
        else:
            escaped.append('%%%02X' % code)
    return ''.join(escaped)

# ''.join([['%%%02X' % ord(c), c][(ord(c) >= 32) and (ord(c) <= 127)] for c in s])

def bit_length(n):
    """
    Return the number of significant bits in C{n}, based on its
    L{deflate_long} encoding without sign padding.  A leading zero byte
    (i.e. C{n == 0}) is reported as 1 bit.
    """
    norm = deflate_long(n, 0)
    top = ord(norm[0])
    if top == 0:
        return 1
    # count the unused high bits of the leading byte
    unused = 0
    mask = 0x80
    while not (top & mask):
        mask >>= 1
        unused += 1
    return len(norm) * 8 - unused

def tb_strings():
    "return the traceback of the exception being handled, as a list of lines"
    return traceback.format_exc().split('\n')

def generate_key_bytes(hashclass, salt, key, nbytes):
    """
    Given a password, passphrase, or other human-source key, scramble it
    through a secure hash into some keyworthy bytes.  This specific algorithm
    is used for encrypting/decrypting private key files.

    @param hashclass: class from L{Crypto.Hash} that can be used as a secure
        hashing function (like C{MD5} or C{SHA}).
    @type hashclass: L{Crypto.Hash}
    @param salt: data to salt the hash with.
    @type salt: string
    @param key: human-entered password or passphrase.
    @type key: string
    @param nbytes: number of bytes to generate.
    @type nbytes: int
    @return: key data
    @rtype: string
    """
    keydata = ''
    digest = ''
    if len(salt) > 8:
        salt = salt[:8]
    while nbytes > 0:
        hash_obj = hashclass()
        if len(digest) > 0:
            hash_obj.update(digest)
        hash_obj.update(key)
        hash_obj.update(salt)
        digest = hash_obj.digest()
        size = min(nbytes, len(digest))
        keydata += digest[:size]
        nbytes -= size
    return keydata

"""
Common API for all public keys.
"""

import base64
from binascii import hexlify, unhexlify
import os


class PKey (object):
    """
    Base class for public keys.
    """

    # known encryption types for private key files.  The cipher class and
    # mode come from PyCrypto; when it could not be imported (limited mode)
    # they are left as None so the class can still be defined and the key
    # sizes are still available for building the hash line.
    if limited:
        _CIPHER_TABLE = {
            'AES-128-CBC': { 'cipher': None, 'keysize': 16, 'blocksize': 16, 'mode': None },
            'DES-EDE3-CBC': { 'cipher': None, 'keysize': 24, 'blocksize': 8, 'mode': None },
        }
    else:
        _CIPHER_TABLE = {
            'AES-128-CBC': { 'cipher': AES, 'keysize': 16, 'blocksize': 16, 'mode': AES.MODE_CBC },
            'DES-EDE3-CBC': { 'cipher': DES3, 'keysize': 24, 'blocksize': 8, 'mode': DES3.MODE_CBC },
        }


    def __init__(self, msg=None, data=None):
        """
        Create a new instance of this public key type.  If C{msg} is given,
        the key's public part(s) will be filled in from the message.  If
        C{data} is given, the key's public part(s) will be filled in from
        the string.

        The base implementation does nothing; subclasses do the real work.

        @param msg: an optional SSH L{Message} containing a public key of this
        type.
        @type msg: L{Message}
        @param data: an optional string containing a public key of this type
        @type data: str

        @raise SSHException: if a key cannot be created from the C{data} or
        C{msg} given, or no key was passed in.
        """
        pass

    def __str__(self):
        """
        Return a string of an SSH L{Message} made up of the public part(s) of
        this key.  This string is suitable for passing to L{__init__} to
        re-create the key object later.

        @return: string representation of an SSH key message.
        @rtype: str
        """
        return ''

    def __cmp__(self, other):
        """
        Compare this key to another.  Returns 0 if this key is equivalent to
        the given key, or non-0 if they are different.  Only the public parts
        of the key are compared, so a public key will compare equal to its
        corresponding private key.

        @param other: key to compare to.
        @type other: L{PKey}
        @return: 0 if the two keys are equivalent, non-0 otherwise.
        @rtype: int
        """
        hs = hash(self)
        ho = hash(other)
        if hs != ho:
            return cmp(hs, ho)
        return cmp(str(self), str(other))

    def get_name(self):
        """
        Return the name of this private key implementation.

        @return: name of this private key type, in SSH terminology (for
        example, C{"ssh-rsa"}).
        @rtype: str
        """
        return ''

    def get_bits(self):
        """
        Return the number of significant bits in this key.  This is useful
        for judging the relative security of a key.

        @return: bits in the key.
        @rtype: int
        """
        return 0

    def can_sign(self):
        """
        Return C{True} if this key has the private part necessary for signing
        data.

        @return: C{True} if this is a private key.
        @rtype: bool
        """
        return False

    def get_fingerprint(self):
        """
        Return an MD5 fingerprint of the public part of this key.  Nothing
        secret is revealed.

        @return: a 16-byte string (binary) of the MD5 fingerprint, in SSH
            format.
        @rtype: str
        """
        # MD5 is hashlib's md5 constructor in this file, which is called
        # directly (it has no PyCrypto-style .new() attribute)
        return MD5(str(self)).digest()

    def get_base64(self):
        """
        Return a base64 string containing the public part of this key.  Nothing
        secret is revealed.  This format is compatible with that used to store
        public key files or recognized host keys.

        @return: a base64 string containing the public part of the key.
        @rtype: str
        """
        return base64.encodestring(str(self)).replace('\n', '')

    def sign_ssh_data(self, rng, data):
        """
        Sign a blob of data with this private key, and return a L{Message}
        representing an SSH signature message.

        The base implementation signs nothing and returns an empty string.

        @param rng: a secure random number generator.
        @type rng: L{Crypto.Util.rng.RandomPool}
        @param data: the data to sign.
        @type data: str
        @return: an SSH signature message.
        @rtype: L{Message}
        """
        return ''

    def verify_ssh_sig(self, data, msg):
        """
        Given a blob of data, and an SSH message representing a signature of
        that data, verify that it was signed with this key.

        The base implementation verifies nothing and returns C{False}.

        @param data: the data that was signed.
        @type data: str
        @param msg: an SSH signature message
        @type msg: L{Message}
        @return: C{True} if the signature verifies correctly; C{False}
            otherwise.
        @rtype: boolean
        """
        return False

    @classmethod
    def from_private_key_file(cls, filename, password=None):
        """
        Create a key object by reading a private key file.  If the private
        key is encrypted and C{password} is not C{None}, the given password
        will be used to decrypt the key (otherwise L{PasswordRequiredException}
        is thrown).  Through the magic of python, this factory method will
        exist in all subclasses of PKey (such as L{RSAKey} or L{DSSKey}), but
        is useless on the abstract PKey class.

        @param filename: name of the file to read
        @type filename: str
        @param password: an optional password to use to decrypt the key file,
            if it's encrypted
        @type password: str
        @return: a new key object based on the given private key
        @rtype: L{PKey}

        @raise IOError: if there was an error reading the file
        @raise PasswordRequiredException: if the private key file is
            encrypted, and C{password} is C{None}
        @raise SSHException: if the key file is invalid
        """
        return cls(filename=filename, password=password)

    @classmethod
    def from_private_key(cls, file_obj, password=None):
        """
        Create a key object by reading a private key from a file (or file-like)
        object.  If the private key is encrypted and C{password} is not C{None},
        the given password will be used to decrypt the key (otherwise
        L{PasswordRequiredException} is thrown).

        @param file_obj: the file to read from
        @type file_obj: file
        @param password: an optional password to use to decrypt the key, if it's
            encrypted
        @type password: str
        @return: a new key object based on the given private key
        @rtype: L{PKey}

        @raise IOError: if there was an error reading the key
        @raise PasswordRequiredException: if the private key file is encrypted,
            and C{password} is C{None}
        @raise SSHException: if the key file is invalid
        """
        return cls(file_obj=file_obj, password=password)

    def _read_private_key_file(self, tag, filename, password=None):
        """
        Read an SSH2-format private key file, looking for a string of the type
        C{"BEGIN xxx PRIVATE KEY"} for some C{xxx}, base64-decode the text we
        find, and return it as a string.  If the private key is encrypted and
        C{password} is not C{None}, the given password will be used to decrypt
        the key (otherwise L{PasswordRequiredException} is thrown).

        @param tag: C{"RSA"} or C{"DSA"}, the tag used to mark the data block.
        @type tag: str
        @param filename: name of the file to read.
        @type filename: str
        @param password: an optional password to use to decrypt the key file,
            if it's encrypted.
        @type password: str
        @return: data blob that makes up the private key.
        @rtype: str

        @raise IOError: if there was an error reading the file.
        @raise PasswordRequiredException: if the private key file is
            encrypted, and C{password} is C{None}.
        @raise SSHException: if the key file is invalid.
        """
        f = open(filename, 'r')
        try:
            return self._read_private_key(tag, f, password)
        finally:
            # close the file even when parsing raises
            f.close()

    def _read_private_key(self, tag, f, password=None):
        """
        Parse a PEM private key from the open file C{f}.

        Side effects: sets C{self.type} (0 for RSA, 1 for DSA) and, for an
        encrypted key, C{self.hash} (the C{$fastssh$} line).

        @param tag: ignored; the key type is detected from the first line of
            the file instead.
        @type tag: str
        @param f: file (or file-like) object to read from.
        @type f: file
        @param password: password used to try to decrypt an encrypted key.
        @type password: str
        @return: the decoded key blob for an unencrypted key; the decrypted
            blob (padding stripped when it looks valid) for an encrypted key;
            C{None} when running in limited mode or for an unsupported key
            size.
        @rtype: str

        @raise PasswordRequiredException: if the key is encrypted and
            C{password} is C{None}.
        @raise SSHException: if the key file is invalid.
        """
        lines = f.readlines()
        if not lines:
            raise SSHException('not a valid private key file (file is empty)')

        # the key type is sniffed from the first line; anything that is not
        # RSA is treated as DSA
        if "BEGIN RSA PRIVATE" in lines[0]:
            tag = "RSA"
            self.type = 0
        else:
            self.type = 1
            tag = "DSA"

        begin_marker = '-----BEGIN ' + tag + ' PRIVATE KEY-----'
        end_marker = '-----END ' + tag + ' PRIVATE KEY-----'

        start = 0
        while (start < len(lines)) and (lines[start].strip() != begin_marker):
            start += 1
        if start >= len(lines):
            raise SSHException('not a valid ' + tag + ' private key file')
        # parse any headers first
        headers = {}
        start += 1
        while start < len(lines):
            parts = lines[start].split(': ')
            if len(parts) == 1:
                break
            headers[parts[0].lower()] = parts[1].strip()
            start += 1
        # find end; the bounds check comes first so a missing END marker
        # cannot index past the last line
        end = start
        while (end < len(lines)) and (lines[end].strip() != end_marker):
            end += 1
        # if we trudged to the end of the file, just try to cope.
        try:
            data = base64.decodestring(''.join(lines[start:end]))
        except base64.binascii.Error, e:
            raise SSHException('base64 decoding error: ' + str(e))

        if 'proc-type' not in headers:
            # unencrypted: done
            return data
        # encrypted keyfile: will need a password
        if headers['proc-type'] != '4,ENCRYPTED':
            raise SSHException('Unknown private key structure "%s"' % headers['proc-type'])
        try:
            encryption_type, saltstr = headers['dek-info'].split(',')
        except (KeyError, ValueError):
            # header missing, or not exactly "<cipher>,<salt>"
            raise SSHException('Can\'t parse DEK-info in private key file')
        if encryption_type not in self._CIPHER_TABLE:
            raise SSHException('Unknown private key cipher "%s"' % encryption_type)
        # if no password was passed in, raise an exception pointing out that we need one
        if password is None:
            raise PasswordRequiredException('Private key file is encrypted')
        entry = self._CIPHER_TABLE[encryption_type]
        keysize = entry['keysize']
        salt = unhexlify(saltstr)

        # file-like objects are not guaranteed to carry a name
        source_name = getattr(f, 'name', '<unknown>')
        # cipher id written into the hash line: 0 for the 24-byte key
        # (DES-EDE3-CBC), 1 for the 16-byte key (AES-128-CBC)
        if keysize == 24:
            cipher_id = 0
        elif keysize == 16:
            cipher_id = 1
        else:
            print >> sys.stderr, "%s uses unsupported cipher, please file a bug! " % source_name
            return None
        self.hash = "%s:$fastssh$%s$%s$%s$%s$%s" % (source_name, cipher_id, len(saltstr) // 2, saltstr, len(data), binascii.hexlify(data))

        if limited:
            # no cipher implementation available: hash line only
            return None

        key = generate_key_bytes(MD5, salt, password, keysize)
        data = entry['cipher'].new(key, entry['mode'], salt).decrypt(data)
        # strip the padding, using the block size of the cipher in use
        try:
            return PKCS7Encoder(entry['blocksize']).decode(data)
        except ValueError:  # incorrect password
            return data

def chunks(l, n):
    for i in xrange(0, len(l), n):
        yield l[i:i+n]


class RSADSSKey (PKey):

    def __init__(self, msg=None, data=None, filename=None, password=None, vals=None, file_obj=None):
        self.n = None
        self.e = None
        self.d = None
        self.p = None
        self.q = None
        if file_obj is not None:
            self._from_private_key(file_obj, password)
            return
        if filename is not None:
            self._from_private_key_file(filename, password)
            return
        if (msg is None) and (data is not None):
            msg = Message(data)
        if vals is not None:
            self.e, self.n = vals
        self.size = bit_length(self.n)

    def __str__(self):
        m = Message()
        m.add_string('ssh-rsa')
        m.add_mpint(self.e)
        m.add_mpint(self.n)
        return str(m)

    def __hash__(self):
        h = hash(self.get_name())
        h = h * 37 + hash(self.e)
        h = h * 37 + hash(self.n)
        return hash(h)

    def get_name(self):
        return 'ssh-rsa'

    def get_bits(self):
        return self.size

    ###  internals...

    def _from_private_key_file(self, filename, password):
        data = self._read_private_key_file('RSA', filename, password)

        if limited:
            print self.hash
            return
        try:
            if self.type == 0:
                self._decode_key(data)
            else:
                self._decode_dss_key(data)
            print >> sys.stderr, "%s has no password!" % filename
        except SSHException:
            print self.hash

    def _from_private_key(self, file_obj, password):
        """used for converting older format hashes"""
        data = self._read_private_key('RSA', file_obj, password)
        if limited:
            print self.hash
            return
        try:
            if self.type == 0:
                self._decode_key(data)
            else:
                self._decode_dss_key(data)
            print >> sys.stderr, "%s has no password!" % filename
        except SSHException:
            print self.hash

    def _decode_key(self, data):
        # private key file contains:
        # RSAPrivateKey = { version = 0, n, e, d, p, q, d mod p-1, d mod q-1, q**-1 mod p }
        try:
            keylist = BER(data).decode()
        except BERException:
            raise SSHException('Unable to parse key file')
        if (type(keylist) is not list) or (len(keylist) < 4) or (keylist[0] != 0):
            raise SSHException('Not a valid RSA private key file (bad ber encoding)')
        self.n = keylist[1]
        self.e = keylist[2]
        self.d = keylist[3]
        # not really needed
        self.p = keylist[4]
        self.q = keylist[5]
        self.size = bit_length(self.n)

    def _decode_dss_key(self, data):
        # private key file contains:
        # DSAPrivateKey = { version = 0, p, q, g, y, x }
        try:
            keylist = BER(data).decode()
        except BERException, x:
            raise SSHException('Unable to parse key file: ' + str(x))
        if (type(keylist) is not list) or (len(keylist) < 6) or (keylist[0] != 0):
            raise SSHException('not a valid DSA private key file (bad ber encoding)')
        self.p = keylist[1]
        self.q = keylist[2]
        self.g = keylist[3]
        self.y = keylist[4]
        self.x = keylist[5]
        self.size = bit_length(self.p)


class PKCS7Encoder(object):
    '''
    RFC 2315: PKCS#7 page 21
    Some content-encryption algorithms assume the
    input length is a multiple of k octets, where k > 1, and
    let the application define a method for handling inputs
    whose lengths are not a multiple of k octets. For such
    algorithms, the method shall be to pad the input at the
    trailing end with k - (l mod k) octets all having value k -
    (l mod k), where l is the length of the input. In other
    words, the input is padded at the trailing end with one of
    the following strings:

             01 -- if l mod k = k-1
            02 02 -- if l mod k = k-2
                        .
                        .
                        .
          k k ... k k -- if l mod k = 0

    The padding can be removed unambiguously since all input is
    padded and no padding string is a suffix of another. This
    padding method is well-defined if and only if k < 256;
    methods for larger k are an open issue for further study.
    '''
    def __init__(self, k=16):
        # k is the block size in octets
        self.k = k

    def decode(self, text):
        '''
        Remove the PKCS#7 padding from a text string

        @param text: the padded text for which the padding is to be removed.
        @type text: str
        @return: C{text} without its trailing padding.
        @rtype: str

        @raise ValueError: when the input padding is missing or corrupt
            (empty input, pad length outside 1..k or longer than the input,
            or pad octets that do not all carry the pad length).
        '''
        if len(text) == 0:
            raise ValueError('Input is not padded or padding is corrupt')
        last = text[-1]
        val = ord(last)
        if (val < 1) or (val > self.k) or (val > len(text)):
            raise ValueError('Input is not padded or padding is corrupt')
        # every padding octet must carry the pad length
        if text[-val:] != last * val:
            raise ValueError('Input is not padded or padding is corrupt')
        return text[:-val]

    def encode(self, text):
        '''
        Pad an input string according to PKCS#7

        @param text: the text to encode.
        @type text: str
        @return: C{text} followed by 1..k padding octets.
        @rtype: str
        '''
        val = self.k - (len(text) % self.k)
        return text + chr(val) * val


if __name__ == "__main__":
    if len(sys.argv) < 2:
        # nothing to do: tell the user how the script is meant to be called
        print >> sys.stderr, "Usage: %s <SSH private key file(s)>" % sys.argv[0]
    for keyfile in sys.argv[1:]:
        # loading the key prints its hash line (or a notice) as a side
        # effect, so the returned object is not needed; the empty password
        # is what gets tried against encrypted keys
        RSADSSKey.from_private_key_file(keyfile, '')
