import os
from aes import AES
import hashlib

# A function that does nothing
no_op = lambda *x: None

def main():    
    k = os.urandom(16)
    cipher = AES(k)
    secret = b''.join(k[i:i+1]*4 for i in range(16))

    flag = os.environ.get('FLAG', 'hkcert21{***********************REDACTED***********************}').encode()
    assert len(flag) == 64

    flag = b''.join([cipher.encrypt(flag[i:i+16]) for i in range(0, 64, 16)])
    print(f'Hey. This is the encrypted flag you gotta decrypt: {flag.hex()}.')

    options = ['ark', 'sb', 'sr', 'mc']
    suboptions = ['data', 'secret']
    
    for _ in range(128):
        [option, suboption, *more] = input('> ').split(' ')
        if option not in options: raise Exception('invalid option!')
        if suboption not in suboptions: raise Exception('invalid suboption!')
        
        if suboption == 'secret':
            options.remove(option)
            message = secret
        else:
            message = bytes.fromhex(more[0])
            if len(message) != 16: raise Exception('invalid length!')
            message = message * 4

        if option == 'ark':
            # Stage 1: "AddRoundKey" does nothing
            cipher = AES(k)
            cipher._add_round_key = no_op
            ciphertext = cipher.encrypt(message[0:16])
        elif option == 'sb':
            # Stage 2: "SubBytes" does nothing
            cipher = AES(k)
            cipher._sub_bytes = no_op
            ciphertext = cipher.encrypt(message[16:32])
        elif option == 'sr':
            # Stage 3: "ShiftRows" does nothing
            cipher = AES(k)
            cipher._shift_rows = no_op
            ciphertext = cipher.encrypt(message[32:48])
        elif option == 'mc':
            # Stage 4: "MixColumns" does nothing
            cipher = AES(k)
            cipher._mix_columns = no_op
            ciphertext = cipher.encrypt(message[48:64])

        print(ciphertext.hex())

if __name__ == '__main__':
    main()

