import importlib
import json
import os
import socket
import sys
import time
import traceback
from datetime import datetime

# Add the Lambda runtime directory to the Python path so the `awslambdaric`
# import below can be resolved from it (presumably where the managed runtime
# ships that package — confirm for custom container images).
sys.path.append('/var/runtime')

# Default collector address, port and transport. send_event_log falls back to
# these when AS_SERVER_ADDR / AS_SERVER_PORT / AS_SERVER_PROTOCOL are unset.
# The port is a string, like the environment value it stands in for; it is
# converted with int() before use.
DEFAULT_SERVER_ADDR = '127.0.0.1'
DEFAULT_SERVER_PORT = '9015'
DEFAULT_PROTOCOL = 'udp'

# Import the Lambda runtime client
from awslambdaric import bootstrap

def tcp_client(host, port, message, *, timeout=2.0):
    """Send *message* to ``host:port`` over a fresh TCP connection.

    One connection carries exactly one message; the connection is closed
    (by the ``with`` block) as soon as the payload has been written.

    Best-effort: any failure (connection refused, timeout, name resolution
    error, ...) is printed and swallowed so telemetry never breaks the caller.

    Args:
        host: collector address.
        port: collector TCP port (int).
        message: text payload; sent UTF-8 encoded.
        timeout: seconds allowed for the connect and each send before giving
            up. Without a bound, an unreachable collector would block the
            invocation until the Lambda itself times out. ``None`` blocks
            indefinitely.
    """
    try:
        with socket.create_connection((host, port), timeout=timeout) as client_socket:
            client_socket.sendall(message.encode())
    except Exception as e:
        print(f"[AppSentinels] TCP send failed: {e}")

def udp_client(host, port, message):
    """Send *message* to ``host:port`` as a single UDP datagram.

    Best-effort: failures are printed and swallowed so telemetry never
    breaks the caller. This includes a payload larger than the maximum
    datagram size, which ``sendto`` rejects.

    Args:
        host: collector address.
        port: collector UDP port (int).
        message: text payload; sent UTF-8 encoded.
    """
    try:
        # The context manager closes the socket; no explicit close() needed.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as client_socket:
            client_socket.sendto(message.encode(), (host, port))
    except Exception as e:
        print(f"[AppSentinels] UDP send failed: {e}")

def send_event_log(event_log):
    """Serialize *event_log* to JSON and send it to the AppSentinels collector.

    The destination comes from environment variables, falling back to the
    module defaults:

    * ``AS_SERVER_ADDR``: collector host (``DEFAULT_SERVER_ADDR``).
    * ``AS_SERVER_PORT``: collector port, parsed with ``int()``
      (``DEFAULT_SERVER_PORT``).
    * ``AS_SERVER_PROTOCOL``: ``tcp`` selects TCP (case-insensitive,
      surrounding whitespace ignored). Any other value sends over UDP
      (``DEFAULT_PROTOCOL``).

    Best-effort: serialization or configuration errors (for example a
    non-JSON-serializable value or a non-numeric port) are printed rather
    than raised.
    """
    try:
        message = json.dumps(event_log)
        server = os.environ.get('AS_SERVER_ADDR', DEFAULT_SERVER_ADDR)
        port = int(os.environ.get('AS_SERVER_PORT', DEFAULT_SERVER_PORT))
        # Normalise so that e.g. "TCP" or " tcp" does not silently fall back to UDP.
        protocol = os.environ.get('AS_SERVER_PROTOCOL', DEFAULT_PROTOCOL).strip().lower()

        if protocol == 'tcp':
            tcp_client(server, port, message)
        else:
            udp_client(server, port, message)

    except Exception as e:
        print("[AppSentinels] Failed to send eventlog: ", e)

def _body_size(body):
    """Return the size used to decide whether *body* is too large to ship.

    Strings and bytes-like values are measured with ``len()``, so strings
    count characters. ``None`` counts as empty. Any other value (dict, list,
    ...) is measured by the length of its JSON encoding, falling back to
    ``str()`` when it is not JSON-serializable.
    """
    if body is None:
        return 0
    if isinstance(body, (str, bytes, bytearray)):
        return len(body)
    try:
        return len(json.dumps(body))
    except (TypeError, ValueError):
        return len(str(body))


def check_body_size(event_log, max_body_size=60 * 1024):
    """Blank out oversized request/response bodies in *event_log*.

    For each of ``event_log['request']`` and ``event_log['response']`` that
    is a dict with a ``'body'`` key whose size exceeds *max_body_size*:

    * the section is replaced with a shallow copy whose ``'body'`` is ``""``;
    * the matching ``requestBodyTruncated`` / ``responseBodyTruncated`` flag
      is set to ``True``.

    Flags are only written on truncation; callers that want explicit
    ``False`` values must set them beforehand. Sections that are not dicts
    (``None``, str, list, ...), or that have no ``'body'``, are left alone.

    Copying before blanking matters because the sections are typically the
    Lambda event and the handler's actual response. Modifying them in place
    would empty the body of the response returned to the client.

    Args:
        event_log: dict holding ``'request'`` and (optionally) ``'response'``.
        max_body_size: largest body size kept as-is (see ``_body_size``).
    """
    sections = (
        ('request', 'requestBodyTruncated'),
        ('response', 'responseBodyTruncated'),
    )
    for section_key, flag_key in sections:
        section = event_log.get(section_key)
        if not isinstance(section, dict) or 'body' not in section:
            continue
        if _body_size(section['body']) > max_body_size:
            truncated = dict(section)
            truncated['body'] = ""
            event_log[section_key] = truncated
            event_log[flag_key] = True

def _load_original_handler():
    """Import and return the function named by the ``_HANDLER`` env var.

    ``_HANDLER`` has the form ``module.path.function``. ``/`` separators in
    the module part are treated as package separators.

    Raises:
        KeyError: ``_HANDLER`` is not set.
        ValueError: ``_HANDLER`` contains no ``.``.
        ImportError / AttributeError: the module or function cannot be found.
    """
    handler_spec = os.environ['_HANDLER']

    # Make modules in the current working directory importable.
    app_root = os.getcwd()
    if app_root not in sys.path:
        sys.path.append(app_root)

    module_path, handler_name = handler_spec.rsplit('.', 1)
    # import_module returns the leaf module for a dotted path, whereas
    # __import__('pkg.mod') would return the top-level package ``pkg``.
    module = importlib.import_module(module_path.replace('/', '.'))
    return getattr(module, handler_name)


def _report_invocation(event, context, response, latency):
    """Build the event log for one invocation, trim it and ship it."""
    # Shallow-copy dict payloads so nothing done to the event log can alter
    # the event the handler received or the response returned to Lambda.
    event_log = {'request': dict(event) if isinstance(event, dict) else event}
    event_log['request_id'] = context.aws_request_id
    event_log['response'] = dict(response) if isinstance(response, dict) else response
    event_log['latency'] = latency
    event_log['requestBodyTruncated'] = False
    event_log['responseBodyTruncated'] = False
    # TODO: Add an option to compress the payload here.
    check_body_size(event_log)
    send_event_log(event_log)


def wrapper(event, context):
    """Lambda entry point: run the customer handler and report the invocation.

    The handler named by ``_HANDLER`` is resolved and invoked with the
    unchanged *event* and *context*, and its response is returned.

    * Errors while resolving the handler, or raised by the handler itself,
      are printed and re-raised so Lambda reports them as usual.
    * Errors while building or sending the event log are printed and
      swallowed, so telemetry cannot fail an invocation whose handler
      succeeded.

    No event log is sent when the handler raises.
    """
    try:
        original_handler_func = _load_original_handler()
    except Exception:
        traceback.print_exc()
        raise

    # time.perf_counter() returns fractional seconds.
    start_time = time.perf_counter()
    try:
        response = original_handler_func(event, context)
    except Exception:
        print('[AppSentinels] Exception in original functional handler')
        traceback.print_exc()
        raise
    latency = (time.perf_counter() - start_time) * 1000  # milliseconds

    try:
        _report_invocation(event, context, response, latency)
    except Exception:
        print('[AppSentinels] Failed to capture event log')
        traceback.print_exc()

    return response

if __name__ == "__main__":
    # Started as the runtime entry script: hand control to the Lambda runtime
    # interface client, registering this module's `wrapper` as the handler so
    # every invocation goes through it.
    runtime_api_address = os.environ['AWS_LAMBDA_RUNTIME_API']
    app_root = os.getcwd()

    # Pass the computed application root (the variable, not the string "app_root").
    bootstrap.run(app_root, f"{__name__}.wrapper", runtime_api_address)