diff --git a/agent.py b/agent.py index cd3c799..e6567d4 100644 --- a/agent.py +++ b/agent.py @@ -1,64 +1,199 @@ -import socket -import ssl -import nftables -import json -import psutil +############################################################################################################ +# author Aaron Moser + +############################################################################################################ +# +import platform,socket,re,uuid,json,psutil,logging,ssl,nftables,queue,threading + +############################################################################################################ +# Boolean which indicates if the server thread should keep running or not. +server_running = False + +# Create message queue for messages received from client. +receive_message_queue = queue.Queue() + +############################################################################################################ +# +def start_server(host_ip, port, logs_path): + server_running = True + + # Create an SSL context. + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_3 + ssl_context.load_cert_chain(certfile='server.crt', keyfile='server.key') + + # While no severe error occured, try to keep server running and connect to upcoming + while server_running: + # Create a socket and bind it to the specified address and port. + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.bind((host_ip, port)) + server_socket.listen(1) + print(f"Server listening on {host_ip}:{port}") + + # Accept incoming connections. + client_socket, client_address = server_socket.accept() + print(f"Accepted connection from {client_address}") + + # Wrap the client socket with SSL using the SSL context. + ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True) + + # Create and start thread which only puts received messages + # from client into message queue. + receive_thread = threading.Thread( + target=receive_messages_and_put_into_message_queue, + args=(receive_message_queue, ssl_socket) + ) + receive_thread.start() + + # Create and start thread which takes messages from receive queue + # and calls functions depending on which message was received. + response_thread = threading.Thread( + target=process_received_messages_and_send_response, + args=(receive_message_queue, ssl_socket, logs_path) + ) + response_thread.start() + + # Wait until both threads are finished. + receive_thread.join() + response_thread.join() + + # Close the connection. + ssl_socket.close() + + print("Agent has been stopped.") + +############################################################################################################ +# +def get_host_ip(): + host = '127.0.0.1' + host_input_valid = False + + while not host_input_valid: + host_input = input("Enter server ip:") + try: + host = str(host_input) + host_input_valid = True + except ValueError: + print("That is not a valid IP.") + return host + +############################################################################################################ +# +def get_port(): + port = 5000 + port_input_valid = False + while not port_input_valid: + port_input = input("Enter server port:") + try: + port = int(port_input) + port_input_valid = True + except ValueError: + print("That is not a valid number.") + return port + +def get_syslog_path(): + syslog_path = "/var/log/syslog" + path_input_valid = False + while not path_input_valid: + syslog_path_input = input("Enter syslog path:") + try: + syslog_path = str(syslog_path_input) + path_input_valid = True + except: + print("Invalid path.") + return syslog_path + +############################################################################################################ +# Read data from the client and put it into message queue. +def receive_messages_and_put_into_message_queue(message_queue, ssl_socket): + error_occured = False + while not error_occured and server_running: + try: + data = ssl_socket.recv(1024).decode('utf-8') + message_queue.put(data) + print(f"Received from client: {data} and put into message queue.") + except: + error_occured = True + +############################################################################################################ +# +def process_received_messages_and_send_response(message_queue, ssl_socket, logs_path): + while server_running: + if not message_queue.empty(): + message = message_queue.get() + match message: + case["stopAgent"]: + server_running = False + case["getSysInf"]: + send_system_information(ssl_socket) + case["getCon"]: + send_connections_info(ssl_socket) + case["getLogs"]: + send_logs(logs_path) + case["getNFTConf"]: + send_nftables_configuration(ssl_socket) + case _: + print(f"Unknown message: {message}.") + +################################################# +# +def send_system_information(ssl_socket): + system_info = get_system_info() + ssl_socket.send(system_info.encode('utf-8')) + +# source: https://stackoverflow.com/questions/3103178/how-to-get-the-system-info-with-python#answer-58420504 +def get_system_info(): + try: + info={} + info['platform']=platform.system() + info['platform-release']=platform.release() + info['platform-version']=platform.version() + info['architecture']=platform.machine() + info['hostname']=socket.gethostname() + info['ip-address']=socket.gethostbyname(socket.gethostname()) + info['mac-address']=':'.join(re.findall('..', '%012x' % uuid.getnode())) + info['processor']=platform.processor() + info['ram']=str(round(psutil.virtual_memory().total / (1024.0 **3)))+" GB" + return json.dumps(info) + except Exception as e: + logging.exception(e) + +################################################# + +def send_connections_info(ssl_socket): + network_connections = get_network_connections_as_string() + ssl_socket.send(network_connections.encode('utf-8')) def get_network_connections_as_string(): kinds = ['inet', 'inet4', 'inet6', 'tcp', 'tcp4', 'tcp6', 'udp', 'udp4', 'udp6', 'unix', 'all'] network_connections_as_string = "" network_connections = psutil.net_connections(kind=kinds[0]) - for conn in network_connections: - network_connections_as_string += str(conn) + "\n" + for connection in network_connections: + network_connections_as_string += str(connection) + "\n" return network_connections_as_string +################################################# +# +def send_logs(logs_path): + print("TODO") + +################################################# +# +def send_nftables_configuration(ssl_socket): + nftables_configuration = fetch_nftables_config() + ssl_socket.send(nftables_configuration.encode('utf-8')) + def fetch_nftables_config(): nft = nftables.Nftables() nft.set_json_output(True) rc,output,error = nft.cmd("list ruleset") return output -def start_server(): - host = '127.0.0.1' - port = 5000 - - # Create an SSL context - ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - ssl_context.minimum_version = ssl.TLSVersion.TLSv1_3 - ssl_context.load_cert_chain(certfile='server.crt', keyfile='server.key') - - # Create a socket and bind it to the specified address and port - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.bind((host, port)) - server_socket.listen(1) - - print(f"Server listening on {host}:{port}") - - while True: - # Accept incoming connections - client_socket, client_address = server_socket.accept() - print(f"Accepted connection from {client_address}") - - # Wrap the client socket with SSL using the SSL context - ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True) - - try: - # Read data from the client - data = ssl_socket.recv(1024).decode('utf-8') - print(f"Received from client: {data}") - - # Create output string to send to client - output = fetch_nftables_config() - stringToSend = output - network_connections = get_network_connections_as_string() - stringToSend += network_connections - #print("Data sent to client:\n" + stringToSend) - # Send a response to the client - ssl_socket.send(stringToSend.encode('utf-8')) - - finally: - # Close the connection - ssl_socket.close() +############################################################################################################ if __name__ == "__main__": - start_server() + host_ip = get_host_ip() + port = get_port() + logs_path = get_syslog_path() + + start_server(host_ip, port, logs_path)