import os
import random
from Crypto.Cipher import AES
from Crypto.Util.number import isPrime as is_prime
from Crypto.Util.Padding import pad

# 256 bits for random-number generator
N = 0xcdc21452d0d82fbce447a874969ebb70bcc41a2199fbe74a2958d0d280000001
G = 0x5191654c7d85905266b0a88aea88f94172292944674b97630853f919eeb1a070
H = 0x7468657365206e756d6265727320617265206f6620636f757273652073757321

# More challenge-specific parameters
E     = 65537 # The public modulus
CALLS = 65537 # The number of operations allowed

# Generate a 512-bit prime
def generate_prime(seed):
    random.seed(seed)
    while True:
        p = random.getrandbits(512) | (1<<511)
        if p % E == 1: continue
        if not is_prime(p): continue
        return p

# Defines a 1024-bit RSA key
class Key:
    def __init__(self, p, q):
        self.p = p
        self.q = q
        self.n = p*q
        self.e = E
        phi = (p-1) * (q-1)
        self.d = pow(self.e, -1, phi)

    def encrypt(self, m):
        return pow(m, self.e, self.n)
    
    def decrypt(self, c):
        return pow(c, self.d, self.n)

# Defines an user
class User:
    def __init__(self, master_secret):
        self.master_secret = master_secret
        self.key = None

    def generate_key(self):
        id = random.getrandbits(256)
        self.key = Key(
            generate_prime(self.master_secret + int.to_bytes(pow(G, id, N), 32, 'big')),
            generate_prime(self.master_secret + int.to_bytes(pow(H, id, N), 32, 'big'))
        )
    
    def send(self, m):
        if self.key is None: raise Exception('no key is defined!')

        m = int(m, 16)
        print(hex(self.key.encrypt(m)))
    
    def get_secret(self):
        if self.key is None: raise Exception('no key is defined!')

        m = int.from_bytes(self.master_secret, 'big')
        print(hex(self.key.encrypt(m)))


def main():
    flag = os.environ.get('FLAG', 'hkcert21{***REDACTED***}')
    flag = pad(flag.encode(), 16)

    master_secret = os.urandom(32)
    admin = User(master_secret)

    for _ in range(CALLS):
        command = input('[cmd] ').split(' ')

        try:
            if command[0] == 'send':
                # Encrypts a hexed message
                admin.send(command[1])
            elif command[0] == 'pkey':
                # Refreshs a new set of key
                admin.generate_key()
            elif command[0] == 'backup':
                # Gets the encrypted master secret
                admin.get_secret()
            elif command[0] == 'flag':
                cipher = AES.new(master_secret, AES.MODE_CBC, b'\0'*16)
                encrypted_flag = cipher.encrypt(flag)
                print(encrypted_flag.hex())
        except Exception as err:
            raise err
            print('nope')

if __name__ == '__main__':
    main()

'''
Side-note: I knew this is _very_ similar to "Long Story Short" in HKCERT CTF 2021, but rest assured that this is a different challenge.
...:)
'''
