265 lines
8.2 KiB
Python
265 lines
8.2 KiB
Python
#!/usr/bin/env python3
|
|
import argparse
import asyncio
import logging
import random
import secrets
from asyncio import StreamReader, StreamWriter

from insecurelib import *
|
|
|
|
async def read_line_safe(reader):
    """Read one line from an async stream reader and return it as text.

    Args:
        reader: Object exposing an awaitable ``readline()`` that returns
            ``bytes`` (for example an ``asyncio.StreamReader``).

    Returns:
        The decoded line with surrounding whitespace stripped, or ``None``
        when nothing could be read: the peer closed the connection
        (``readline()`` returned no bytes), the read failed, or the bytes
        were not valid text. An empty line (just a newline) yields ``''``.
    """
    try:
        raw = await reader.readline()
    except (OSError, ValueError):
        # OSError: connection-level failure; ValueError: line exceeded the
        # stream's buffer limit.
        return None

    if not raw:
        # readline() returns b'' only at EOF. Report that as "no message"
        # instead of handing an empty string to the caller.
        return None

    try:
        return raw.decode().strip()
    except UnicodeDecodeError:
        return None
|
|
|
|
|
|
def log_error(e, client_writer=None):
    """Print an error and, when a writer is given, forward it to the client.

    Args:
        e: The error (or any printable object) to report.
        client_writer: Optional object with a ``write(bytes)`` method. When
            it is falsy nothing is sent.

    Sending is best effort: any exception raised while building or writing
    the message is reported on stdout and otherwise ignored.
    """
    print(f"Error occurred: {e}")

    if not client_writer:
        return

    try:
        payload = f"Error: {str(e)}\n".encode()
        client_writer.write(payload)
    except Exception:
        print("Could not send error to client")
|
|
|
|
|
|
# Module-level logger; level and format are configured in the __main__ guard.
log = logging.getLogger(__name__)
# Registry of live connections, filled and emptied by accept_client().
clients = {} # task -> (reader, writer)

# Assigned by main() from get_primes(); stays None until main() has run.
primes: list[int] | None = None
# Key material is loaded at import time from the current working directory;
# importing this module fails if either PEM file is missing.
# privKey is used to sign this side's STS message in do_STS_key_exchange().
with open('alice_private.pem', 'rb') as f:
    alice_private = f.read()
privKey = ECC.import_key(alice_private)
# bob_public is used to verify the peer's STS signature.
with open('bob_public.pem', 'rb') as f:
    bob_public = ECC.import_key(f.read())
|
|
|
|
|
|
# channel class, initiates STS key exchange
|
|
class AuthenticatedChannel:
|
|
def __init__(self, reader, writer):
|
|
self.reader = reader
|
|
self.writer = writer
|
|
self.shared_key = None
|
|
|
|
def is_authenticated(self) -> bool:
|
|
return self.shared_key is not None
|
|
|
|
async def send_encrypted(self, msg: bytes):
|
|
"""Sends an encrypted message. Adds newline to the message"""
|
|
if not self.is_authenticated():
|
|
return
|
|
msg = encrypt(self.shared_key, msg)
|
|
self.writer.write(msg + b'\n')
|
|
await self.writer.drain()
|
|
|
|
async def recv_encrypted(self) -> bytes | None:
|
|
"""receives encrypted message. Returns None if no message is received"""
|
|
if not self.is_authenticated():
|
|
return None
|
|
data = await read_line_safe(self.reader)
|
|
if data is None:
|
|
return None
|
|
return decrypt(self.shared_key, data)
|
|
|
|
async def do_STS_key_exchange(self):
|
|
# pick p, g and private secret a + public part X=g^a mod p
|
|
p = STS_PRIME
|
|
g = STS_GENERATOR
|
|
a = random.randint(1, p - 1)
|
|
X = pow(g, a, mod=p)
|
|
|
|
# send p,q and public keypart
|
|
pgX = f'{p},{g},{X}\n'
|
|
self.writer.write(pgX.encode())
|
|
|
|
Ys = await read_line_safe(self.reader)
|
|
log.info(f'received "{Ys}" as Ys (public key & sig server2)')
|
|
|
|
if Ys is None:
|
|
self.writer.write('no public key received\n'.encode())
|
|
await self.writer.drain()
|
|
return None
|
|
|
|
Y, s = Ys.split(',')
|
|
Y = int(Y)
|
|
if Y >= p:
|
|
self.writer.write(f"Y ({Y}) can't be larger or equal to p ({p})!".encode())
|
|
await self.writer.drain()
|
|
return None
|
|
|
|
# calculate shared key
|
|
print("calculating key now")
|
|
key = str(pow(Y, a, mod=p))
|
|
key = KDRV256(key.encode())
|
|
# decrypt and verify signature
|
|
decrypted_sig = decrypt(key, s)
|
|
if not verify(bob_public, message=f'{Y},{X}'.encode(), signature=decrypted_sig):
|
|
self.writer.write('Signature verification failed\n'.encode())
|
|
await self.writer.drain()
|
|
return None
|
|
# sign X and Y and send signature
|
|
sig = sign(privKey, f'{X},{Y}'.encode())
|
|
sig = encrypt(key, sig)
|
|
self.writer.write(sig + b'\n')
|
|
print("finished the do_STS_key_exchange")
|
|
await self.writer.drain()
|
|
|
|
self.shared_key = key
|
|
|
|
|
|
async def do_session_key_DH_exchange(channel: AuthenticatedChannel) -> bytes | None:
|
|
"""
|
|
Receives initial parameters and sends own public keypart.
|
|
All communication is sent over the authenticated channel.
|
|
"""
|
|
# receive p,q and public keypart to other server (over the client) and wait for response
|
|
pgX = await channel.recv_encrypted()
|
|
|
|
if pgX is None:
|
|
return
|
|
pgX = pgX.decode().rstrip('\n')
|
|
|
|
if pgX.count(',') != 2:
|
|
await channel.send_encrypted(
|
|
'Invalid amount of arguments (expected 3; p,g,X)\n'.encode()
|
|
)
|
|
return
|
|
|
|
p, g, X = map(int, pgX.split(','))
|
|
print("p, g, x is here: ", (p,g,X))
|
|
# two checks to prevent DOSes and improve performance
|
|
if not check_int_range(p):
|
|
await channel.send_encrypted(f'{p} must be in [{0}..{MAX_PRIME}]'.encode())
|
|
return
|
|
if not check_int_range(g):
|
|
await channel.send_encrypted(f'{g} must be in [{0}..{MAX_PRIME}]'.encode())
|
|
return
|
|
print("two checks to prevent doses and improve performance finished")
|
|
# check if parameters are valid
|
|
if not is_prime(p):
|
|
await channel.send_encrypted(f'{p} is not a prime number!'.encode())
|
|
return
|
|
|
|
if not is_primitive_root(g, p):
|
|
await channel.send_encrypted(f'{g} is not a primitive root of {p}!'.encode())
|
|
return
|
|
|
|
if X >= p:
|
|
await channel.send_encrypted(f"X ({X} can't be larger or equal to p {p}!".encode())
|
|
return
|
|
print("check if parameters are valid finished")
|
|
# create own public/private key parts:
|
|
b = random.randint(1, p - 1)
|
|
Y = pow(g, b, mod=p)
|
|
print("sending encryption")
|
|
await channel.send_encrypted(f'{Y}'.encode())
|
|
print("sending encyrption finished")
|
|
# calculate shared key
|
|
key = str(pow(X, b, mod=p))
|
|
key = KDRV256(key.encode())
|
|
return key
|
|
|
|
|
|
async def handle_client(client_reader: StreamReader, client_writer: StreamWriter):
    """Serve one connection: authenticate, agree a session key, request the flag.

    Flow:
      1. Log the peer address.
      2. Run the STS exchange; stop if the channel is not authenticated.
      3. Run the session-key DH exchange; stop if no key was derived.
      4. Send the flag request encrypted with the session key (and again by
         the channel), then read, decrypt and print the reply.

    Any exception raised after step 1 is logged and reported to the client
    on a best-effort basis. The function never raises for protocol errors.
    """
    try:
        log_new_connection(client_writer)
    except Exception as e:
        log.error(f'EXCEPTION (get peername): {e} ({type(e)})')
        return

    try:
        authenticated_channel = AuthenticatedChannel(client_reader, client_writer)
        await authenticated_channel.do_STS_key_exchange()
        if not authenticated_channel.is_authenticated():
            log.info('Authenticated key exchange failed!')
            return

        # do session key DH exchange
        session_key = await do_session_key_DH_exchange(authenticated_channel)
        if session_key is None:
            # The exchange already told the peer what was wrong; do not go
            # on to encrypt with a missing key.
            log.info('Session key exchange failed!')
            return

        msg = 'Hey Bob, plz send me my f14g :-)'
        encrypted_msg = encrypt(session_key, msg.encode())
        print("sending encypretd message about igving me flag")
        await authenticated_channel.send_encrypted(encrypted_msg)

        data = await authenticated_channel.recv_encrypted()
        print('Received data: ', data)
        if data is None:
            return

        decrypted_flag = decrypt(session_key, data.decode())
        print(f'flag: {decrypted_flag}')

    except Exception as e:
        # Connection-level boundary: record the failure locally, then try to
        # tell the client.
        log.exception('error while handling client')
        try:
            error = f'something went wrong with the previous message! Error: {e}\n'
            client_writer.write(error.encode())
            await client_writer.drain()
        except Exception:
            # Best effort only: the connection may already be gone.
            log.debug('could not report error to client', exc_info=True)
            return
|
|
|
|
|
|
def accept_client(client_reader: StreamReader, client_writer: StreamWriter):
    """Connection callback for ``asyncio.start_server``.

    Schedules ``handle_client`` for the new connection, records the task in
    ``clients`` and registers a callback that removes the entry and closes
    the writer once the task finishes.

    Must be called from inside the running event loop, which is how
    ``start_server`` invokes it.
    """
    # create_task() is the documented way to schedule a coroutine; the docs
    # discourage instantiating asyncio.Task directly.
    task = asyncio.create_task(handle_client(client_reader, client_writer))
    clients[task] = (client_reader, client_writer)

    def client_done(task):
        # pop() tolerates an entry that was already removed elsewhere.
        clients.pop(task, None)
        client_writer.close()
        # Retrieve the task's exception so it is logged here rather than
        # reported later as "Task exception was never retrieved".
        if not task.cancelled() and task.exception() is not None:
            log.error(f'client handler failed: {task.exception()!r}')
        log.info('connection closed')

    task.add_done_callback(client_done)
|
|
|
|
|
|
def log_new_connection(client_writer):
    """Log the remote address of a newly accepted connection.

    Reads ``peername`` from the writer's extra info. When it is missing an
    error is logged; otherwise the first two items of the peername are
    logged in ``host:port`` form.
    """
    peer = client_writer.get_extra_info('peername')
    if peer is not None:
        log.info(f'new connection from: {peer[0]}:{peer[1]}')
    else:
        log.error('Could not get ip of client')
|
|
|
|
|
|
def main():
    """Parse arguments, precompute the prime list and run the server.

    Command line:
        -p / --port   TCP port to listen on (default 20206).

    The server listens on all interfaces (``host=None``) and runs until
    interrupted with Ctrl-C, after which the listening socket and the event
    loop are closed.
    """
    global primes

    cmd = argparse.ArgumentParser()
    cmd.add_argument('-p', '--port', type=int, default=20206)
    args = cmd.parse_args()

    # done as global list, so that we can simply pick one for every connection
    primes = get_primes(MIN_PRIME, MAX_PRIME)
    print('generated primes from {} to {}'.format(primes[0], primes[-1]))

    # start server
    # new_event_loop() avoids the deprecated implicit loop creation of
    # asyncio.get_event_loop() when no loop is running.
    loop = asyncio.new_event_loop()
    server = loop.run_until_complete(
        asyncio.start_server(accept_client, host=None, port=args.port)
    )
    log.info('Server waiting for connections')

    # run_forever() is called once, inside the try, so Ctrl-C is handled
    # instead of escaping as a traceback.
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        pass
    finally:
        server.close()
        loop.close()
|
|
|
|
|
|
if __name__ == '__main__':
    # "INFO:asyncio:poll took 25.960 seconds" is annoying, so only let
    # warnings and above through from the asyncio logger.
    logging.getLogger('asyncio').setLevel(logging.WARNING)

    # Root logger: INFO level, timestamped records with module and line.
    logging.basicConfig(
        format='%(asctime)s %(levelname)s [%(module)s:%(lineno)d] %(message)s',
        level=logging.INFO,
    )

    main()
|