Splitted tasks of sending and receiving messages to multiple threads. Added pattern matching and specification for received messages.
This commit is contained in:
233
agent.py
233
agent.py
@@ -1,64 +1,199 @@
|
||||
import socket
|
||||
import ssl
|
||||
import nftables
|
||||
import json
|
||||
import psutil
|
||||
############################################################################################################
|
||||
# author Aaron Moser
|
||||
|
||||
############################################################################################################
|
||||
#
|
||||
import platform,socket,re,uuid,json,psutil,logging,ssl,nftables,queue,threading
|
||||
|
||||
############################################################################################################
|
||||
# Boolean which indicates if the server thread should keep running or not.
|
||||
server_running = False
|
||||
|
||||
# Create message queue for messages received from client.
|
||||
receive_message_queue = queue.Queue()
|
||||
|
||||
############################################################################################################
|
||||
#
|
||||
def start_server(host_ip, port, logs_path):
|
||||
server_running = True
|
||||
|
||||
# Create an SSL context.
|
||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_3
|
||||
ssl_context.load_cert_chain(certfile='server.crt', keyfile='server.key')
|
||||
|
||||
# While no severe error occured, try to keep server running and connect to upcoming
|
||||
while server_running:
|
||||
# Create a socket and bind it to the specified address and port.
|
||||
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
server_socket.bind((host_ip, port))
|
||||
server_socket.listen(1)
|
||||
print(f"Server listening on {host_ip}:{port}")
|
||||
|
||||
# Accept incoming connections.
|
||||
client_socket, client_address = server_socket.accept()
|
||||
print(f"Accepted connection from {client_address}")
|
||||
|
||||
# Wrap the client socket with SSL using the SSL context.
|
||||
ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True)
|
||||
|
||||
# Create and start thread which only puts received messages
|
||||
# from client into message queue.
|
||||
receive_thread = threading.Thread(
|
||||
target=receive_messages_and_put_into_message_queue,
|
||||
args=(receive_message_queue, ssl_socket)
|
||||
)
|
||||
receive_thread.start()
|
||||
|
||||
# Create and start thread which takes messages from receive queue
|
||||
# and calls functions depending on which message was received.
|
||||
response_thread = threading.Thread(
|
||||
target=process_received_messages_and_send_response,
|
||||
args=(receive_message_queue, ssl_socket, logs_path)
|
||||
)
|
||||
response_thread.start()
|
||||
|
||||
# Wait until both threads are finished.
|
||||
receive_thread.join()
|
||||
response_thread.join()
|
||||
|
||||
# Close the connection.
|
||||
ssl_socket.close()
|
||||
|
||||
print("Agent has been stopped.")
|
||||
|
||||
############################################################################################################
|
||||
#
|
||||
def get_host_ip():
|
||||
host = '127.0.0.1'
|
||||
host_input_valid = False
|
||||
|
||||
while not host_input_valid:
|
||||
host_input = input("Enter server ip:")
|
||||
try:
|
||||
host = str(host_input)
|
||||
host_input_valid = True
|
||||
except ValueError:
|
||||
print("That is not a valid IP.")
|
||||
return host
|
||||
|
||||
############################################################################################################
|
||||
#
|
||||
def get_port():
|
||||
port = 5000
|
||||
port_input_valid = False
|
||||
while not port_input_valid:
|
||||
port_input = input("Enter server port:")
|
||||
try:
|
||||
port = int(port_input)
|
||||
port_input_valid = True
|
||||
except ValueError:
|
||||
print("That is not a valid number.")
|
||||
return port
|
||||
|
||||
def get_syslog_path():
|
||||
syslog_path = "/var/log/syslog"
|
||||
path_input_valid = False
|
||||
while not path_input_valid:
|
||||
syslog_path_input = input("Enter syslog path:")
|
||||
try:
|
||||
syslog_path = str(syslog_path_input)
|
||||
path_input_valid = True
|
||||
except:
|
||||
print("Invalid path.")
|
||||
return syslog_path
|
||||
|
||||
############################################################################################################
|
||||
# Read data from the client and put it into message queue.
|
||||
def receive_messages_and_put_into_message_queue(message_queue, ssl_socket):
|
||||
error_occured = False
|
||||
while not error_occured and server_running:
|
||||
try:
|
||||
data = ssl_socket.recv(1024).decode('utf-8')
|
||||
message_queue.put(data)
|
||||
print(f"Received from client: {data} and put into message queue.")
|
||||
except:
|
||||
error_occured = True
|
||||
|
||||
############################################################################################################
|
||||
#
|
||||
def process_received_messages_and_send_response(message_queue, ssl_socket, logs_path):
|
||||
while server_running:
|
||||
if not message_queue.empty():
|
||||
message = message_queue.get()
|
||||
match message:
|
||||
case["stopAgent"]:
|
||||
server_running = False
|
||||
case["getSysInf"]:
|
||||
send_system_information(ssl_socket)
|
||||
case["getCon"]:
|
||||
send_connections_info(ssl_socket)
|
||||
case["getLogs"]:
|
||||
send_logs(logs_path)
|
||||
case["getNFTConf"]:
|
||||
send_nftables_configuration(ssl_socket)
|
||||
case _:
|
||||
print(f"Unknown message: {message}.")
|
||||
|
||||
#################################################
|
||||
#
|
||||
def send_system_information(ssl_socket):
|
||||
system_info = get_system_info()
|
||||
ssl_socket.send(system_info.encode('utf-8'))
|
||||
|
||||
# source: https://stackoverflow.com/questions/3103178/how-to-get-the-system-info-with-python#answer-58420504
|
||||
def get_system_info():
|
||||
try:
|
||||
info={}
|
||||
info['platform']=platform.system()
|
||||
info['platform-release']=platform.release()
|
||||
info['platform-version']=platform.version()
|
||||
info['architecture']=platform.machine()
|
||||
info['hostname']=socket.gethostname()
|
||||
info['ip-address']=socket.gethostbyname(socket.gethostname())
|
||||
info['mac-address']=':'.join(re.findall('..', '%012x' % uuid.getnode()))
|
||||
info['processor']=platform.processor()
|
||||
info['ram']=str(round(psutil.virtual_memory().total / (1024.0 **3)))+" GB"
|
||||
return json.dumps(info)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
#################################################
|
||||
|
||||
def send_connections_info(ssl_socket):
|
||||
network_connections = get_network_connections_as_string()
|
||||
ssl_socket.send(network_connections.encode('utf-8'))
|
||||
|
||||
def get_network_connections_as_string():
|
||||
kinds = ['inet', 'inet4', 'inet6', 'tcp', 'tcp4', 'tcp6', 'udp', 'udp4', 'udp6', 'unix', 'all']
|
||||
network_connections_as_string = ""
|
||||
network_connections = psutil.net_connections(kind=kinds[0])
|
||||
for conn in network_connections:
|
||||
network_connections_as_string += str(conn) + "\n"
|
||||
for connection in network_connections:
|
||||
network_connections_as_string += str(connection) + "\n"
|
||||
return network_connections_as_string
|
||||
|
||||
#################################################
|
||||
#
|
||||
def send_logs(logs_path):
|
||||
print("TODO")
|
||||
|
||||
#################################################
|
||||
#
|
||||
def send_nftables_configuration(ssl_socket):
|
||||
nftables_configuration = fetch_nftables_config()
|
||||
ssl_socket.send(nftables_configuration.encode('utf-8'))
|
||||
|
||||
def fetch_nftables_config():
|
||||
nft = nftables.Nftables()
|
||||
nft.set_json_output(True)
|
||||
rc,output,error = nft.cmd("list ruleset")
|
||||
return output
|
||||
|
||||
def start_server():
|
||||
host = '127.0.0.1'
|
||||
port = 5000
|
||||
|
||||
# Create an SSL context
|
||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_3
|
||||
ssl_context.load_cert_chain(certfile='server.crt', keyfile='server.key')
|
||||
|
||||
# Create a socket and bind it to the specified address and port
|
||||
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
server_socket.bind((host, port))
|
||||
server_socket.listen(1)
|
||||
|
||||
print(f"Server listening on {host}:{port}")
|
||||
|
||||
while True:
|
||||
# Accept incoming connections
|
||||
client_socket, client_address = server_socket.accept()
|
||||
print(f"Accepted connection from {client_address}")
|
||||
|
||||
# Wrap the client socket with SSL using the SSL context
|
||||
ssl_socket = ssl_context.wrap_socket(client_socket, server_side=True)
|
||||
|
||||
try:
|
||||
# Read data from the client
|
||||
data = ssl_socket.recv(1024).decode('utf-8')
|
||||
print(f"Received from client: {data}")
|
||||
|
||||
# Create output string to send to client
|
||||
output = fetch_nftables_config()
|
||||
stringToSend = output
|
||||
network_connections = get_network_connections_as_string()
|
||||
stringToSend += network_connections
|
||||
#print("Data sent to client:\n" + stringToSend)
|
||||
# Send a response to the client
|
||||
ssl_socket.send(stringToSend.encode('utf-8'))
|
||||
|
||||
finally:
|
||||
# Close the connection
|
||||
ssl_socket.close()
|
||||
############################################################################################################
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server()
|
||||
host_ip = get_host_ip()
|
||||
port = get_port()
|
||||
logs_path = get_syslog_path()
|
||||
|
||||
start_server(host_ip, port, logs_path)
|
||||
|
||||
Reference in New Issue
Block a user