This commit is contained in:
2024-12-09 18:22:38 +09:00
parent ab0cbebefc
commit c4c4547706
959 changed files with 174888 additions and 6 deletions

241
week06/hard/alice.py Normal file
View File

@ -0,0 +1,241 @@
#!/usr/bin/env python3
import argparse
import random
import asyncio
import logging
from asyncio import StreamReader, StreamWriter
from insecurelib import *
from pwn_utils.utils import read_line_safe
log = logging.getLogger(__name__)
clients = {} # task -> (reader, writer)
primes: list[int] | None = None
with open('alice_private.pem', 'rb') as f:
alice_private = f.read()
privKey = ECC.import_key(alice_private)
with open('bob_public.pem', 'rb') as f:
bob_public = ECC.import_key(f.read())
# channel class, initiates STS key exchange
class AuthenticatedChannel:
def __init__(self, reader, writer):
self.reader = reader
self.writer = writer
self.shared_key = None
def is_authenticated(self) -> bool:
return self.shared_key is not None
async def send_encrypted(self, msg: bytes):
"""Sends an encrypted message. Adds newline to the message"""
if not self.is_authenticated():
return
msg = encrypt(self.shared_key, msg)
self.writer.write(msg + b'\n')
await self.writer.drain()
async def recv_encrypted(self) -> bytes | None:
"""receives encrypted message. Returns None if no message is received"""
if not self.is_authenticated():
return None
data = await read_line_safe(self.reader)
if data is None:
return None
return decrypt(self.shared_key, data)
async def do_STS_key_exchange(self):
# pick p, g and private secret a + public part X=g^a mod p
p = STS_PRIME
g = STS_GENERATOR
a = random.randint(1, p - 1)
X = pow(g, a, mod=p)
# send p,q and public keypart
pgX = f'{p},{g},{X}\n'
self.writer.write(pgX.encode())
Ys = await read_line_safe(self.reader)
log.info(f'received "{Ys}" as Ys (public key & sig server2)')
if Ys is None:
self.writer.write('no public key received\n'.encode())
await self.writer.drain()
return None
Y, s = Ys.split(',')
Y = int(Y)
if Y >= p:
self.writer.write(f"Y ({Y}) can't be larger or equal to p ({p})!".encode())
await self.writer.drain()
return None
# calculate shared key
key = str(pow(Y, a, mod=p))
key = KDRV256(key.encode())
# decrypt and verify signature
decrypted_sig = decrypt(key, s)
if not verify(bob_public, message=f'{Y},{X}'.encode(), signature=decrypted_sig):
self.writer.write('Signature verification failed\n'.encode())
await self.writer.drain()
return None
# sign X and Y and send signature
sig = sign(privKey, f'{X},{Y}'.encode())
sig = encrypt(key, sig)
self.writer.write(sig + b'\n')
await self.writer.drain()
self.shared_key = key
async def do_session_key_DH_exchange(channel: AuthenticatedChannel) -> bytes | None:
"""
Receives initial parameters and sends own public keypart.
All communication is sent over the authenticated channel.
"""
# receive p,q and public keypart to other server (over the client) and wait for response
pgX = await channel.recv_encrypted()
if pgX is None:
return
pgX = pgX.decode().rstrip('\n')
if pgX.count(',') != 2:
await channel.send_encrypted(
'Invalid amount of arguments (expected 3; p,g,X)\n'.encode()
)
return
p, g, X = map(int, pgX.split(','))
# two checks to prevent DOSes and improve performance
if not check_int_range(p):
await channel.send_encrypted(f'{p} must be in [{0}..{MAX_PRIME}]'.encode())
return
if not check_int_range(g):
await channel.send_encrypted(f'{g} must be in [{0}..{MAX_PRIME}]'.encode())
return
# check if parameters are valid
if not is_prime(p):
await channel.send_encrypted(f'{p} is not a prime number!'.encode())
return
if not is_primitive_root(g, p):
await channel.send_encrypted(f'{g} is not a primitive root of {p}!'.encode())
return
if X >= p:
await channel.send_encrypted(f"X ({X} can't be larger or equal to p {p}!".encode())
return
# create own public/private key parts:
b = random.randint(1, p - 1)
Y = pow(g, b, mod=p)
await channel.send_encrypted(f'{Y}'.encode())
# calculate shared key
key = str(pow(X, b, mod=p))
key = KDRV256(key.encode())
return key
async def handle_client(client_reader: StreamReader, client_writer: StreamWriter):
global primes
try:
log_new_connection(client_writer)
except Exception as e:
log.error(f'EXCEPTION (get peername): {e} ({type(e)})')
return
try:
authenticated_channel = AuthenticatedChannel(client_reader, client_writer)
await authenticated_channel.do_STS_key_exchange()
if not authenticated_channel.is_authenticated():
log.info('Authenticated key exchange failed!')
return
# do session key DH exchange
session_key = await do_session_key_DH_exchange(authenticated_channel)
msg = 'Hey Bob, plz send me my f14g :-)'
encrypted_msg = encrypt(session_key, msg.encode())
await authenticated_channel.send_encrypted(encrypted_msg)
data = await authenticated_channel.recv_encrypted()
print('Received data: ', data)
if data is None:
return
decrypted_flag = decrypt(session_key, data.decode())
print(f'flag: {decrypted_flag}')
except Exception as e:
try:
error = f'something went wrong with the previous message! Error: {e}\n'
client_writer.write(error.encode())
await client_writer.drain()
except Exception:
return
def accept_client(client_reader: StreamReader, client_writer: StreamWriter):
task = asyncio.Task(handle_client(client_reader, client_writer))
clients[task] = (client_reader, client_writer)
def client_done(task):
del clients[task]
client_writer.close()
log.info('connection closed')
task.add_done_callback(client_done)
def log_new_connection(client_writer):
remote = client_writer.get_extra_info('peername')
if remote is None:
log.error('Could not get ip of client')
return
log.info(f'new connection from: {remote[0]}:{remote[1]}')
def main():
global primes
cmd = argparse.ArgumentParser()
cmd.add_argument('-p', '--port', type=int, default=20206)
args = cmd.parse_args()
# done as global list, so that we can simply pick one for every connection
primes = get_primes(MIN_PRIME, MAX_PRIME)
print('generated primes from {} to {}'.format(primes[0], primes[-1]))
# start server
loop = asyncio.get_event_loop()
f = asyncio.start_server(accept_client, host=None, port=args.port)
log.info('Server waiting for connections')
loop.run_until_complete(f)
loop.run_forever()
try:
loop.run_forever()
except KeyboardInterrupt:
pass
if __name__ == '__main__':
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s %(levelname)s [%(module)s:%(lineno)d] %(message)s',
)
# "INFO:asyncio:poll took 25.960 seconds" is annyoing
logging.getLogger('asyncio').setLevel(logging.WARNING)
main()

229
week06/hard/bob.py Normal file
View File

@ -0,0 +1,229 @@
#!/usr/bin/env python3
import argparse
import asyncio
import logging
import subprocess
import random
from asyncio import StreamReader, StreamWriter
from insecurelib import *
from pwn_utils.utils import read_line_safe
log = logging.getLogger(__name__)
clients = {} # task -> (reader, writer)
primes: list[int] | None = None
with open('bob_private.pem', 'rb') as f:
bob_private = f.read()
privKey = ECC.import_key(bob_private)
with open('alice_public.pem', 'rb') as f:
alice_public = ECC.import_key(f.read())
# channel class, receives initial STS key exchange params
class AuthenticatedChannel:
def __init__(self, reader, writer):
self.reader = reader
self.writer = writer
self.shared_key = None
def is_authenticated(self) -> bool:
return self.shared_key is not None
async def send_encrypted(self, msg: bytes):
"""Sends an encrypted message. Adds newline to the message"""
if not self.is_authenticated():
return
msg = encrypt(self.shared_key, msg)
self.writer.write(msg + b'\n')
await self.writer.drain()
async def recv_encrypted(self) -> bytes | None:
"""receives encrypted message. Returns None if no message is received"""
if not self.is_authenticated():
return None
data = await read_line_safe(self.reader)
if data is None:
return None
return decrypt(self.shared_key, data)
async def do_STS_key_exchange(self):
# receive p,q and public keypart to other server (over the client) and wait for response
pgX = await read_line_safe(self.reader)
if pgX is None:
return
if pgX.count(',') != 2:
self.writer.write('Invalid amount of arguments (expected 3; p,g,X)\n'.encode())
await self.writer.drain()
return
p, g, X = map(int, pgX.split(','))
# primality and size checks not necessary since fixed values from RFC 3526 are used for STS key exchange
# create own public/private key parts:
b = random.randint(1, p - 1)
Y = pow(g, b, mod=p)
# calculate shared key
key = str(pow(X, b, mod=p))
key = KDRV256(key.encode())
# sign Y and X and send Y and signature
sig = sign(privKey, f'{Y},{X}'.encode())
sig = encrypt(key, sig)
message = f'{Y},{sig.decode()}\n'.encode()
self.writer.write(message)
await self.writer.drain()
answer_sig = await read_line_safe(self.reader)
if answer_sig is None:
return
decrypted_sig = decrypt(key, answer_sig)
if not verify(alice_public, message=f'{X},{Y}'.encode(), signature=decrypted_sig):
self.writer.write('Signature verification failed\n'.encode())
await self.writer.drain()
return None
self.shared_key = key
async def do_session_key_DH_exchange(channel: AuthenticatedChannel) -> bytes | None:
"""
Initiates session key DH exchange.
All communication is sent over the authenticated channel.
"""
# pick p, g and private secret a + public part X=g^a mod p
p = random.choice(primes)
g = random.choice(primroots(p))
a = random.randint(1, p - 1)
X = pow(g, a, mod=p)
# send p,q and public keypart
pgX = f'{p},{g},{X}'
await channel.send_encrypted(pgX.encode())
Y = await channel.recv_encrypted()
log.info(f'received "{Y}" as Y (public key)')
if Y is None:
await channel.send_encrypted('no public key received'.encode())
return None
Y = int(Y.decode().rstrip('\n'))
if Y >= p:
await channel.send_encrypted(f"Y ({Y}) can't be larger or equal to p ({p})!".encode())
return None
# calculate shared key
key = str(pow(Y, a, mod=p))
key = KDRV256(key.encode())
return key
async def handle_client(client_reader: StreamReader, client_writer: StreamWriter):
try:
log_new_connection(client_writer)
except Exception as e:
log.error(f'EXCEPTION (get peername): {e} ({type(e)})')
return
try:
authenticated_channel = AuthenticatedChannel(client_reader, client_writer)
await authenticated_channel.do_STS_key_exchange()
if not authenticated_channel.is_authenticated():
log.info('Authenticated key exchange failed!')
return
# do session key DH exchange
session_key = await do_session_key_DH_exchange(authenticated_channel)
message = await authenticated_channel.recv_encrypted()
if message is None:
return
decrypted_msg = decrypt(session_key, message.decode())
print('decrypted_msg: ', decrypted_msg)
if decrypted_msg.decode() != 'Hey Bob, plz send me my f14g :-)':
await authenticated_channel.send_encrypted(
'Critical error: unknown message'.encode()
)
return
flag = subprocess.check_output('flag')
encrypted_flag = encrypt(session_key, flag)
await authenticated_channel.send_encrypted(encrypted_flag)
except UnicodeDecodeError as e:
try:
client_writer.write(
f"UnicodeDecodeError: {e} (yep, this leaks a lot about the plaintext, but you don't need it ;))\n".encode()
)
except Exception:
log.exception("couldn't handle UnicodeDecodeError")
return
except Exception as e:
print('Exception: ', e)
pass
def accept_client(client_reader: StreamReader, client_writer: StreamWriter):
task = asyncio.Task(handle_client(client_reader, client_writer))
clients[task] = (client_reader, client_writer)
def client_done(task):
del clients[task]
client_writer.close()
log.info('connection closed')
task.add_done_callback(client_done)
def log_new_connection(client_writer):
remote = client_writer.get_extra_info('peername')
if remote is None:
log.error('Could not get ip of client')
return
log.info(f'new connection from: {remote[0]}:{remote[1]}')
def main():
global primes
cmd = argparse.ArgumentParser()
cmd.add_argument('-p', '--port', type=int, default=20306)
args = cmd.parse_args()
# done as global list, so that we can simply pick one for every connection
primes = get_primes(MIN_PRIME, MAX_PRIME)
print('generated primes from {} to {}'.format(primes[0], primes[-1]))
# start server
loop = asyncio.get_event_loop()
f = asyncio.start_server(accept_client, host=None, port=args.port)
log.info('Server waiting for connections')
loop.run_until_complete(f)
loop.run_forever()
try:
loop.run_forever()
except KeyboardInterrupt:
pass
if __name__ == '__main__':
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s %(levelname)s [%(module)s:%(lineno)d] %(message)s',
)
# "INFO:asyncio:poll took 25.960 seconds" is annyoing
logging.getLogger('asyncio').setLevel(logging.WARNING)
main()

65
week06/hard/client.py Normal file
View File

@ -0,0 +1,65 @@
#!/usr/bin/env python3
import random
import socket
# Kerckhoffs principle for the win
# here are all the crypto primitives Alice and Bob are using
from insecurelib import KDRV256, HMAC, encrypt, decrypt
# Fill in the right target here
HOST = 'this.is.not.a.valid.domain' # TODO
PORT1 = 20008
PORT2 = 20108
# note the numbers you encounter may be small for demonstration purposes.
# Anyway, please do NOT brute force.
# debug the data sent over the secure channel
# we don't know the key, there is not much to debug here, ...
def debug_secure_channel(s1, s2, data: str):
data = data.rstrip('\n') # remove trailing newline
if len(data) >= 1024:
print(f"from {s1} to {s2}: '{data[:1024]}...'")
else:
print(f"from {s1} to {s2}: '{data}...'")
iv, ciphertext, mac = data.split(',')
assert len(iv) == 16 * 2 # a hexlified byte is two bytes long, the IV should be 16 bytes
assert (
len(ciphertext) % (16 * 2) == 0
) # a hexlified byte is two bytes long, AES block size is 128 bit (16 byte)
assert (
len(mac) == 16 * 2
) # a quite short MAC. Hint: you still don't want to brute force it!
def main():
# We connect to Alice and Bob and relay their messages.
# They send all their communication over us. How convenient :-)
# Dolev-Yao attacker model without any low-level effort.
s1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s1.connect((HOST, PORT1))
s1f = s1.makefile('rw') # file abstraction for the sockets
s2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s2.connect((HOST, PORT2))
s2f = s2.makefile('rw')
# A -> B: p,g,X
data = s1f.readline().rstrip('\n')
print(f"from s1 to s2: '{data}'")
p, g, X = map(int, data.split(','))
# TODO: get the flag
s1f.close()
s2f.close()
s1.close()
s2.close()
if __name__ == '__main__':
main()

Binary file not shown.

View File

@ -0,0 +1,23 @@
from Crypto.PublicKey import ECC
def generate_keys() -> tuple[str, str]:
key = ECC.generate(curve='ed25519')
private_key = key.export_key(format='PEM')
public_key = key.public_key().export_key(format='PEM')
return private_key, public_key
def write_keys(name: str, private_key: str, public_key: str):
with open(f'{name}_private.pem', 'wb') as f:
f.write(private_key.encode())
with open(f'{name}_public.pem', 'wb') as f:
f.write(public_key.encode())
if __name__ == '__main__':
privA, pubA = generate_keys()
privB, pubB = generate_keys()
write_keys('alice', privA, pubA)
write_keys('bob', privB, pubB)
print('Keys generated and written to files')

299
week06/hard/insecurelib.py Normal file
View File

@ -0,0 +1,299 @@
#!/usr/bin/env python3
"""
WARNING: cryptolib only for demonstration purposes
Never use the following crypto functions in any production code!
The code features:
* Timing side channels
* Only works for horribly small numbers
* No permormance at all
"""
import math
from Crypto.Hash import HMAC as realHMAC
from binascii import hexlify, unhexlify
from Crypto.Cipher import AES
from Crypto.Hash import SHA256
from random import getrandbits
from Crypto.PublicKey import ECC
from Crypto.PublicKey.ECC import EccKey
from Crypto.Signature import eddsa
# used for the session key exchange where performance matters
MIN_PRIME = 10000
MAX_PRIME = 20000
# for STS key exchange use secure parameters from RFC3526 (group #14): https://www.ietf.org/rfc/rfc3526.txt
STS_GENERATOR = 2
STS_PRIME = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF
def check_int_range(i: int):
return 0 <= i <= MAX_PRIME
def is_prime(num: int):
"""
basic primality test for the given number
"""
# WARNING: horribly insecure example code!
for j in range(2, int(math.sqrt(num) + 1)):
if (num % j) == 0:
return False
return True
def get_primes(min_prime: int, max_prime: int):
"""
calculate primes that are in the range [min_prime, max_prime]
"""
# WARNING: horribly insecure example code!
# http://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n/3035188#3035188
def primes(n):
"""Returns a list of primes < n"""
sieve = [True] * n
for i in range(3, int(n**0.5) + 1, 2):
if sieve[i]:
sieve[i * i :: 2 * i] = [False] * ((n - i * i - 1) // (2 * i) + 1)
return [2] + [i for i in range(3, n, 2) if sieve[i]]
primes_up_to = primes(max_prime + 1)
# what takes so long is the debug here
for p in primes_up_to:
assert is_prime(p)
return [p for p in primes_up_to if p >= min_prime]
def primefactors(x: int):
"""
compute primefactors of a number
"""
# WARNING: horribly insecure example code!
factorlist = []
loop = 2
while loop <= x:
if x % loop == 0:
x /= loop
factorlist.append(loop)
else:
loop += 1
return factorlist
def primroots(p: int):
"""
given one prime number, compute all primitive roots of it
"""
# WARNING: horribly insecure example code!
g = get_primitive_root(p) # get first primitive root
znorder = p - 1
is_coprime = lambda x: math.gcd(x, znorder) == 1
good_odd_integers = filter(is_coprime, range(1, p, 2))
all_primroots = [pow(g, k, mod=p) for k in good_odd_integers]
all_primroots.sort()
return all_primroots
def is_primitive_root(g: int, p: int):
"""
test if DH parameters are correct
"""
# WARNING: horribly insecure example code!
phi = p - 1
for factor in set(primefactors(phi)):
# if pow(m,int(phi/factor))%p equals one for any prime factor, it isn't a primitive root, otherwise it is
if pow(g, int(phi / factor), mod=p) == 1:
return False
return True
def get_primitive_root(p: int):
"""
compute a primitive root of a prime p
"""
# WARNING: horribly insecure example code!
# p is prime, so phi(p) is p-1
phi = p - 1
# check all 2<=g<p if they are a primitive root of p, stop at the first one
for g in range(2, p):
is_prim_root = True
for factor in set(primefactors(phi)):
# if pow(g,int(phi/factor))%p equals one for any prime factor, it isn't a primitive root, otherwise it is
if pow(g, int(phi / factor), mod=p) == 1:
is_prim_root = False
break
# if we found one root, stop
if is_prim_root:
return g
def KDRV256(b: bytes):
"""
Key Derivation Function.
returns a 256-bit key
"""
h = SHA256.new()
h.update(b)
return h.digest()
def HMAC(k: bytes, b: bytes):
"""
returns: 128-bit MAC
128-bit MAC is very short but should be enough in comparison with
those horribly small numbers we use in NetSec
"""
h = realHMAC.new(k)
h.update(b)
return h.digest()
def encrypt(key: bytes, plaintext: bytes) -> bytes:
"""
key: 256 bit
first 128 bit for AES
last 128 bit for MAC
plaintext: in binary
returns: IV,ciphertext,MAC
binary
"""
# WARNING: horribly insecure example code!
assert len(key) == 32 # AES-128 + 128-bit MAC
key_enc = key[:16]
key_int = key[16:]
# Cryptographically insecure randomness in IV
iv = bytes(getrandbits(8) for _ in range(AES.block_size))
# add padding
# http://stackoverflow.com/questions/14179784/python-encrypting-with-pycrypto-aes
padding_length = 16 - (len(plaintext) % 16)
plaintext += bytes([padding_length]) * padding_length
# encrypt plaintext
cipher = AES.new(key_enc, AES.MODE_CBC, iv)
ciphertext = cipher.encrypt(plaintext)
mac = HMAC(key_int, ciphertext)
message = hexlify(iv) + b';' + hexlify(ciphertext) + b';' + hexlify(mac)
return message
def decrypt(key: bytes, message: str) -> bytes:
"""
key: 256 bit
message: IV,ciphertext,MAC
where IV,ciphertext,MAC are hexlified strings
string
example: "696e46845a0e69c18747b76fc087d3b5,15...0c,448ca984c4dd3f3454d1f311443802ed"
returns: decrypted plaintext
binary
"""
# WARNING: horribly insecure example code!
assert len(key) == 32 # AES-128 + 128-bit MAC
key_enc = key[:16]
key_int = key[16:]
assert not message.endswith('\n'), 'message should not end with a newline!'
try:
iv, ciphertext, mac = message.split(';')
iv = unhexlify(iv)
ciphertext = unhexlify(ciphertext)
mac = unhexlify(mac)
assert len(mac) == 16
except Exception as e:
raise Exception(f'iv;cipertext;mac not readable in message "{message}": {e}')
# check MAC
# TODO: timing side channel? NetSec hint: probably not practically exploitable
mac_computed = HMAC(key_int, ciphertext)
if mac_computed != mac:
raise Exception('MAC verification error')
# decrypt ciphertext
cipher = AES.new(key_enc, AES.MODE_CBC, iv)
plaintext = cipher.decrypt(ciphertext)
# remove padding
plaintext = plaintext[: -plaintext[-1]]
return plaintext
def sign(key: EccKey, message: bytes) -> bytes:
"""
key: private key
message: in binary
returns: signature
"""
signer = eddsa.new(key, 'rfc8032')
signature = signer.sign(message)
return signature
def verify(key: EccKey, message: bytes, signature: bytes) -> bool:
"""
key: public key
message: in binary
signature: in binary
returns: True if signature is valid
"""
verifier = eddsa.new(key, 'rfc8032')
try:
verifier.verify(message, signature)
return True
except ValueError:
return False
def test():
assert len(HMAC(b'keyXkeyXkeyXkeyX', 'foobar'.encode())) == 16
assert len(KDRV256(b'creating long keys from low-entropy input is still weak')) == 32
assert decrypt(b'keyX' * 8, encrypt(b'keyX' * 8, b'foobar').decode()) == b'foobar'
assert (
decrypt(b'keyX' * 8, encrypt(b'keyX' * 8, b'foobar' * 32).decode()) == b'foobar' * 32
)
assert (
decrypt(b'keyX' * 8, encrypt(b'keyX' * 8, b'X' * 16).decode()) == b'XXXXXXXXXXXXXXXX'
)
assert decrypt(b'keyX' * 8, encrypt(b'keyX' * 8, b'').decode()) == b''
key = ECC.generate(curve='ed25519')
private_key = key
public_key = key.public_key()
message = b'hello world'
signature = sign(private_key, message)
assert verify(public_key, message, signature)
assert not verify(public_key, message + b'!', signature)
print('tests ok')
if __name__ == '__main__':
test()