#!/usr/bin/env python
import hashlib
import hmac
import os
import random
import secrets
import string
from binascii import unhexlify
from signal import alarm

from secret import flag

# Per-level settings, indexed by ``level - 1`` (levels run 1..4).
# ALARM_TIME: seconds passed to signal.alarm() when a level starts.
ALARM_TIME = [80,160,240,900]
# NUMCIPHERS: number of chosen plaintexts the player may encrypt in a level.
NUMCIPHERS = [2,2,4,52]

class BlockCipher():
    """Toy Feistel cipher on 8-byte blocks with ``r`` rounds.

    The round function is ``F(x) = P(S(x))``: ``S`` substitutes each byte of
    a 32-bit word through ``sbox`` and ``P`` moves each of the 32 bits to the
    position given by that round's ``pbox``. ``sbox`` and every ``pbox`` are
    drawn from the module-level ``random`` generator at construction time;
    only the round keys come from ``key``.
    """

    # Round permutations: at least this many are always drawn; with r > 4 one
    # is drawn per round so that every round has its own pbox.
    _MIN_PBOXES = 4

    def __init__(self, key, r):
        """Build an ``r``-round cipher.

        Args:
            key: bytes of length ``4 * r``, split into ``r`` big-endian
                32-bit round keys.
            r: number of Feistel rounds.

        Raises:
            ValueError: if ``len(key) != 4 * r``.
        """
        if len(key) != r * 4:
            raise ValueError(
                "key must be %d bytes for %d rounds, got %d" % (r * 4, r, len(key))
            )
        sbox = list(range(256))
        random.shuffle(sbox)
        self.sbox = sbox
        pbox = []
        for _ in range(max(self._MIN_PBOXES, r)):
            perm = list(range(32))
            random.shuffle(perm)
            pbox.append(perm)
        self.pbox = pbox
        self.r = r
        self.subkeys = [int.from_bytes(key[i:i + 4], 'big') for i in range(0, len(key), 4)]

    def _crypt(self, block, rounds):
        """Run the Feistel network over one 8-byte block.

        ``rounds`` is the order in which round indices (selecting both the
        subkey and the pbox) are applied.
        """
        if len(block) != 8:
            raise ValueError("block must be exactly 8 bytes, got %d" % len(block))
        L = int.from_bytes(block[:4], 'big')
        R = int.from_bytes(block[4:], 'big')
        for i in rounds:
            L, R = R, L ^ BlockCipher.F(self.sbox, self.pbox[i], R ^ self.subkeys[i])
        # Undo the final swap so that decryption is the same network run with
        # the round order reversed.
        L, R = R, L
        return L.to_bytes(4, 'big') + R.to_bytes(4, 'big')

    def encrypt(self, plain):
        """Encrypt one 8-byte block and return the 8-byte ciphertext.

        Raises:
            ValueError: if ``plain`` is not exactly 8 bytes long.
        """
        return self._crypt(plain, range(self.r))

    def decrypt(self, cipher):
        """Decrypt one 8-byte block (inverse of ``encrypt``).

        Raises:
            ValueError: if ``cipher`` is not exactly 8 bytes long.
        """
        return self._crypt(cipher, reversed(range(self.r)))

    def setSbox(self, sbox):
        """Replace the 256-entry byte substitution table."""
        self.sbox = sbox

    def getSbox(self):
        """Return the byte substitution table."""
        return self.sbox

    def setPbox(self, pbox):
        """Replace the list of per-round 32-bit permutations."""
        self.pbox = pbox

    def getPbox(self):
        """Return the list of per-round 32-bit permutations."""
        return self.pbox

    @staticmethod
    def F(sbox, pbox, x):
        """Round function: byte substitution followed by bit permutation."""
        x = BlockCipher.S(sbox, x)
        x = BlockCipher.P(pbox, x)
        return x

    @staticmethod
    def S(sbox, x):
        """Substitute each of the four bytes of the 32-bit word ``x`` via ``sbox``."""
        result = 0
        for shift in (24, 16, 8, 0):
            result |= sbox[(x >> shift) & 0xff] << shift
        return result

    @staticmethod
    def P(pbox, x):
        """Permute the bits of the 32-bit word ``x``.

        Bit index ``i`` counts from the most significant bit (``i == 0`` is
        bit 31). When that bit is set, bit ``pbox[i]`` (counted from the least
        significant end) is set in the result.
        """
        result = 0
        for i in range(32):
            if (x >> (31 - i)) & 1:
                result |= 1 << pbox[i]
        return result

def challenge(level):
    """Run one interactive level of the chosen-plaintext challenge.

    A ``level``-round BlockCipher with a random key is created. Its sbox, its
    pbox and the encryption of a secret random 8-byte block are printed. The
    player may then encrypt ``NUMCIPHERS[level - 1]`` hex-encoded plaintexts
    of their choice and must finally submit the secret block in hex.

    Args:
        level: 1-based level number; also the number of cipher rounds.

    Returns:
        True if the submitted block equals the secret plaintext, else False.

    Raises:
        ValueError: if ``level`` is outside ``1..len(NUMCIPHERS)``.
        Errors raised on malformed player input (bad hex from ``unhexlify``,
        wrong block length from ``cipher.encrypt``) propagate to the caller.
    """
    if not 1 <= level <= len(NUMCIPHERS):
        raise ValueError(
            "level must be between 1 and %d, got %r" % (len(NUMCIPHERS), level)
        )
    key = os.urandom(level * 4)

    cipher = BlockCipher(key, level)
    sbox = cipher.getSbox()
    print("[*] The sbox is : " + str(sbox))
    pbox = cipher.getPbox()
    print("[*] The pbox is : " + str(pbox))
    randomPlain = os.urandom(8)
    randomCipher = cipher.encrypt(randomPlain)
    print("[*] The randomCipher is : " + randomCipher.hex())
    # Encryption oracle: a fixed, level-dependent number of chosen plaintexts.
    for _ in range(NUMCIPHERS[level - 1]):
        p = unhexlify(input("[*] Input your plain : "))
        c = cipher.encrypt(p)
        print("[*] The cipher is : " + c.hex())
    guessRandomPlain = unhexlify(input("[*] Now tell me the randomPlain : "))
    # Constant-time comparison, so the time taken to reject a wrong guess does
    # not depend on how many leading bytes of the guess were correct.
    return hmac.compare_digest(guessRandomPlain, randomPlain)

def proof_of_work():
    """Gate access behind a small SHA-256 proof of work.

    A random 20-character alphanumeric string is generated. Its last 16
    characters and the SHA-256 hex digest of the full string are printed, and
    the client must supply the missing 4-character prefix.

    Returns:
        True if the client's 4 characters reproduce the digest, else False.
    """
    # Reseed the global ``random`` PRNG from OS entropy; BlockCipher's
    # sbox/pbox shuffles later draw from it.
    random.seed(os.urandom(8))
    alphabet = string.ascii_letters + string.digits
    # Use the CSPRNG for the challenge itself so the hidden prefix cannot be
    # predicted from the state of a non-cryptographic generator.
    proof = ''.join(secrets.choice(alphabet) for _ in range(20))
    suffix = proof[4:]
    digest = hashlib.sha256(proof.encode()).hexdigest()
    print("sha256(XXXX+%s) == %s" % (suffix, digest))
    print('Give me XXXX:')
    # Tolerate surrounding whitespace (e.g. a trailing '\r' from CRLF clients).
    x = input().strip()
    return len(x) == 4 and hashlib.sha256((x + suffix).encode()).hexdigest() == digest

def main():
    """Run the game: a proof of work, then levels 1-4 in random order.

    The first failed level ends the game, and the flag is printed only if
    every level is passed. Any unexpected error raised during the levels
    (e.g. malformed hex input, or EOFError when the client disconnects) is
    reported only as "Error!".
    """
    alarm(60)  # deadline covering the proof of work
    if not proof_of_work():
        return
    try:
        count = 0
        levels = list(range(1, 5))
        random.shuffle(levels)
        print("[*] Welcome to the challnge!")
        for level in levels:
            # Each level gets its own deadline; alarm() cancels the pending one.
            alarm(ALARM_TIME[level - 1])
            print("[*] Challenge " + str(count + 1))
            isPassed = challenge(level)
            if isPassed:
                count += 1
                print("[*] Congratulations on passing this challenge!")
            else:
                print("[-] Sorry you are out")
                break
        if count == len(levels):
            print("[*] Ok here is your flag : " + flag)
        else:
            print("[-] See you next time")
    except Exception:
        # Catch ordinary errors only (not KeyboardInterrupt/SystemExit).
        # Presumably kept terse so internals are not shown to the player.
        print("Error!")

# Start the game only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()