Splitted tasks of sending and receiving messages to multiple threads. Added pattern matching and specification for received messages.

This commit is contained in:
WickedJack99
2023-12-23 14:13:41 +01:00
parent 781c148df7
commit d7a5eea805

233
agent.py
View File

@@ -1,64 +1,199 @@
import socket
import ssl
import nftables
import json
import psutil
############################################################################################################
# author Aaron Moser
############################################################################################################
#
import platform,socket,re,uuid,json,psutil,logging,ssl,nftables,queue,threading
############################################################################################################
# Boolean which indicates if the server thread should keep running or not.
server_running = False
# Create message queue for messages received from client.
receive_message_queue = queue.Queue()
############################################################################################################
#
def start_server(host_ip, port, logs_path):
server_running = True
# Create an SSL context.
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_3
ssl_context.load_cert_chain(certfile='server.crt', keyfile='server.key')
# While no severe error occured, try to keep server running and connect to upcoming
while server_running:
# Create a socket and bind it to the specified address and port.
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.bind((host_ip, port))
server_socket.listen(1)
print(f"Server listening on {host_ip}:{port}")
# Accept incoming connections.
client_socket, client_address = server_socket.accept()
print(f"Accepted connection from {client_address}")
# Wrap the client socket with SSL using the SSL context.
ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True)
# Create and start thread which only puts received messages
# from client into message queue.
receive_thread = threading.Thread(
target=receive_messages_and_put_into_message_queue,
args=(receive_message_queue, ssl_socket)
)
receive_thread.start()
# Create and start thread which takes messages from receive queue
# and calls functions depending on which message was received.
response_thread = threading.Thread(
target=process_received_messages_and_send_response,
args=(receive_message_queue, ssl_socket, logs_path)
)
response_thread.start()
# Wait until both threads are finished.
receive_thread.join()
response_thread.join()
# Close the connection.
ssl_socket.close()
print("Agent has been stopped.")
############################################################################################################
#
def get_host_ip():
host = '127.0.0.1'
host_input_valid = False
while not host_input_valid:
host_input = input("Enter server ip:")
try:
host = str(host_input)
host_input_valid = True
except ValueError:
print("That is not a valid IP.")
return host
############################################################################################################
#
def get_port():
port = 5000
port_input_valid = False
while not port_input_valid:
port_input = input("Enter server port:")
try:
port = int(port_input)
port_input_valid = True
except ValueError:
print("That is not a valid number.")
return port
def get_syslog_path():
syslog_path = "/var/log/syslog"
path_input_valid = False
while not path_input_valid:
syslog_path_input = input("Enter syslog path:")
try:
syslog_path = str(syslog_path_input)
path_input_valid = True
except:
print("Invalid path.")
return syslog_path
############################################################################################################
# Read data from the client and put it into message queue.
def receive_messages_and_put_into_message_queue(message_queue, ssl_socket):
error_occured = False
while not error_occured and server_running:
try:
data = ssl_socket.recv(1024).decode('utf-8')
message_queue.put(data)
print(f"Received from client: {data} and put into message queue.")
except:
error_occured = True
############################################################################################################
#
def process_received_messages_and_send_response(message_queue, ssl_socket, logs_path):
while server_running:
if not message_queue.empty():
message = message_queue.get()
match message:
case["stopAgent"]:
server_running = False
case["getSysInf"]:
send_system_information(ssl_socket)
case["getCon"]:
send_connections_info(ssl_socket)
case["getLogs"]:
send_logs(logs_path)
case["getNFTConf"]:
send_nftables_configuration(ssl_socket)
case _:
print(f"Unknown message: {message}.")
#################################################
#
def send_system_information(ssl_socket):
system_info = get_system_info()
ssl_socket.send(system_info.encode('utf-8'))
# source: https://stackoverflow.com/questions/3103178/how-to-get-the-system-info-with-python#answer-58420504
def get_system_info():
try:
info={}
info['platform']=platform.system()
info['platform-release']=platform.release()
info['platform-version']=platform.version()
info['architecture']=platform.machine()
info['hostname']=socket.gethostname()
info['ip-address']=socket.gethostbyname(socket.gethostname())
info['mac-address']=':'.join(re.findall('..', '%012x' % uuid.getnode()))
info['processor']=platform.processor()
info['ram']=str(round(psutil.virtual_memory().total / (1024.0 **3)))+" GB"
return json.dumps(info)
except Exception as e:
logging.exception(e)
#################################################
def send_connections_info(ssl_socket):
network_connections = get_network_connections_as_string()
ssl_socket.send(network_connections.encode('utf-8'))
def get_network_connections_as_string():
kinds = ['inet', 'inet4', 'inet6', 'tcp', 'tcp4', 'tcp6', 'udp', 'udp4', 'udp6', 'unix', 'all']
network_connections_as_string = ""
network_connections = psutil.net_connections(kind=kinds[0])
for conn in network_connections:
network_connections_as_string += str(conn) + "\n"
for connection in network_connections:
network_connections_as_string += str(connection) + "\n"
return network_connections_as_string
#################################################
#
def send_logs(logs_path):
print("TODO")
#################################################
#
def send_nftables_configuration(ssl_socket):
nftables_configuration = fetch_nftables_config()
ssl_socket.send(nftables_configuration.encode('utf-8'))
def fetch_nftables_config():
nft = nftables.Nftables()
nft.set_json_output(True)
rc,output,error = nft.cmd("list ruleset")
return output
def start_server():
host = '127.0.0.1'
port = 5000
# Create an SSL context
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_3
ssl_context.load_cert_chain(certfile='server.crt', keyfile='server.key')
# Create a socket and bind it to the specified address and port
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.bind((host, port))
server_socket.listen(1)
print(f"Server listening on {host}:{port}")
while True:
# Accept incoming connections
client_socket, client_address = server_socket.accept()
print(f"Accepted connection from {client_address}")
# Wrap the client socket with SSL using the SSL context
ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True)
try:
# Read data from the client
data = ssl_socket.recv(1024).decode('utf-8')
print(f"Received from client: {data}")
# Create output string to send to client
output = fetch_nftables_config()
stringToSend = output
network_connections = get_network_connections_as_string()
stringToSend += network_connections
#print("Data sent to client:\n" + stringToSend)
# Send a response to the client
ssl_socket.send(stringToSend.encode('utf-8'))
finally:
# Close the connection
ssl_socket.close()
############################################################################################################
if __name__ == "__main__":
start_server()
host_ip = get_host_ip()
port = get_port()
logs_path = get_syslog_path()
start_server(host_ip, port, logs_path)