import os
import sys
import json
import time
import socket
import traceback
from datetime import datetime
import importlib.util
from typing import Tuple

# Add the Lambda runtime directory to the Python path so that packages
# installed there can be imported below (presumably where awslambdaric lives
# in the managed runtime image — confirm for the image in use).
sys.path.append('/var/runtime')

# Defaults for the AppSentinels event-log collector. send_event_log() lets
# AS_SERVER_ADDR / AS_SERVER_PORT / AS_SERVER_PROTOCOL override them.
# The port is a string, like an environment value, and is converted with
# int() at send time.
DEFAULT_SERVER_ADDR = '127.0.0.1'
DEFAULT_SERVER_PORT = '9015'
DEFAULT_PROTOCOL = 'udp'

# Import the Lambda runtime client. This import is deliberately placed after
# the sys.path change above, so that /var/runtime is searched when it resolves.
from awslambdaric import bootstrap

def tcp_client(host, port, message, timeout=2.0):
    """
    Send ``message`` to ``host:port`` over a short-lived TCP connection.

    This is best-effort: any failure (address resolution, connect, send or
    timeout) is printed and swallowed, so telemetry problems never reach the
    caller.

    Args:
        host: Collector address.
        port: Collector TCP port (int).
        message: Text payload; it is sent encoded as UTF-8.
        timeout: Seconds allowed for the connect and for each blocking send.
            Pass ``None`` to get the previous unbounded (blocking) behaviour.
    """
    try:
        # create_connection applies the timeout to connect(), and the socket
        # keeps it for sendall(). A dead or stalled collector therefore cannot
        # hold up the Lambda response indefinitely. The `with` statement
        # closes the socket, so no explicit close() is needed.
        with socket.create_connection((host, port), timeout=timeout) as client_socket:
            client_socket.sendall(message.encode())
    except Exception as e:
        print(f"[AppSentinels] TCP send failed: {e}")

def udp_client(host, port, message):
    """
    Fire-and-forget ``message`` as a single UDP datagram to ``host:port``.

    The text is encoded with the default (UTF-8) codec. Any error raised
    while encoding, creating the socket or sending is printed and swallowed.
    """
    destination = (host, port)
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            payload = message.encode()
            sock.sendto(payload, destination)
    except Exception as e:
        print(f"[AppSentinels] UDP send failed: {e}")

def send_event_log(event_log):
    """
    Serialise ``event_log`` to JSON and send it to the AppSentinels collector.

    The destination is read from the environment on every call:
    ``AS_SERVER_ADDR`` (default ``DEFAULT_SERVER_ADDR``), ``AS_SERVER_PORT``
    (default ``DEFAULT_SERVER_PORT``) and ``AS_SERVER_PROTOCOL`` (default
    ``DEFAULT_PROTOCOL``). A protocol of ``tcp`` selects TCP. The name is
    matched case-insensitively, ignoring surrounding whitespace. Any other
    value uses UDP.

    This is best-effort. Serialisation and configuration errors (for example,
    a non-JSON-serialisable response or a non-integer port) are printed and
    swallowed. Transport errors are caught inside tcp_client / udp_client.
    """
    try:
        message = json.dumps(event_log)
        server = os.environ.get('AS_SERVER_ADDR', DEFAULT_SERVER_ADDR)
        port = int(os.environ.get('AS_SERVER_PORT', DEFAULT_SERVER_PORT))
        # Normalise the protocol so that 'TCP' or ' tcp ' are not silently
        # treated as UDP.
        protocol = os.environ.get('AS_SERVER_PROTOCOL', DEFAULT_PROTOCOL).strip().lower()

        if protocol == 'tcp':
            tcp_client(server, port, message)
        else:
            udp_client(server, port, message)

    except Exception as e:
        print("[AppSentinels] Failed to send eventlog: ", e)

def check_body_size(event_log, max_body_size=60 * 1024):
    """
    Blank oversized request/response bodies in ``event_log`` before sending.

    This applies to ``event_log['request']`` and ``event_log['response']``
    when the section is a dict and ``len()`` of its ``'body'`` exceeds
    ``max_body_size``. The default is 60 KiB, counted in characters for
    ``str`` bodies. For such a section:

    - the section is replaced in ``event_log`` by a shallow copy whose
      ``'body'`` is ``""``;
    - the matching ``requestBodyTruncated`` / ``responseBodyTruncated`` flag
      is set to True.

    Flags are never set to False here.

    The original request/response objects are NOT modified. The wrapper
    stores the very response object it returns to Lambda in ``event_log``,
    so mutating it in place would strip the body from the real response.

    Some sections are left untouched:

    - missing sections;
    - non-dict sections, such as a ``None`` or plain-string response;
    - sections whose body is ``None``/absent or has no length.
    """
    sections = (
        ('request', 'requestBodyTruncated'),
        ('response', 'responseBodyTruncated'),
    )
    for section, flag in sections:
        payload = event_log.get(section)
        if not isinstance(payload, dict):
            continue
        try:
            # API Gateway sends "body": null for body-less requests.
            body_size = len(payload.get('body') or "")
        except TypeError:
            # The body has no length (e.g. a number), so there is nothing to trim.
            continue
        if body_size > max_body_size:
            event_log[section] = {**payload, 'body': ""}
            event_log[flag] = True

def parse_handler_path(handler_path: str) -> Tuple[str, str]:
    """
    Split a Lambda handler string into ``(module_path, handler_name)``.

    The split happens at the last '.', so dotted module paths stay intact.

    Example inputs:
    - 'lambda_function.lambda_handler' -> ('lambda_function', 'lambda_handler')
    - 'src.handlers.user_handler.lambda_handler'
      -> ('src.handlers.user_handler', 'lambda_handler')

    Raises:
        ValueError: if ``handler_path`` contains no '.', or either side of
        the last '.' is empty (e.g. '.handler' or 'module.').
    """
    module_path, separator, handler_name = handler_path.rpartition('.')
    # Reject empty parts here. Otherwise they surface later as a confusing
    # "module file not found" error.
    if not (separator and module_path and handler_name):
        raise ValueError(f"Invalid handler path format: {handler_path}")
    return module_path, handler_name

def find_module_file(module_path: str, app_root: str) -> Tuple[str, str]:
    """
    Locate the source file for a dotted module path under ``app_root``.

    Returns a ``(full_file_path, module_name)`` tuple, where ``module_name``
    is ``module_path`` unchanged.

    Candidates for ``a.b.c`` are tried in this order:
    1. ``<app_root>/a/b/c.py``            (plain module)
    2. ``<app_root>/a/b/c/__init__.py``   (package, e.g. a handler defined in
       a package's ``__init__``)
    The plain module is tried first, so results are unchanged when both exist.

    Examples:
    - 'lambda_function' -> (/var/task/lambda_function.py, lambda_function)
    - 'src.handlers.user_handler' -> (/var/task/src/handlers/user_handler.py, src.handlers.user_handler)

    Raises:
        FileNotFoundError: if no candidate exists. Before raising, the tree
        under ``app_root`` is printed to help debugging.
    """
    # A single-segment path is just the degenerate nested case, so one join
    # covers both 'lambda_function' and 'src.handlers.user_handler'.
    path_parts = module_path.split('.')
    candidates = (
        os.path.join(app_root, *path_parts) + '.py',
        os.path.join(app_root, *path_parts, '__init__.py'),
    )
    for file_path in candidates:
        if os.path.isfile(file_path):
            return file_path, module_path

    # No file was found, so list the directory contents to help with debugging.
    print("[AppSentinels] Could not find module file. Directory contents:")
    for root, dirs, files in os.walk(app_root):
        print(f"\nDirectory: {root}")
        print(f"Files: {files}")

    raise FileNotFoundError(
        f"Could not find module file for {module_path} "
        f"in {app_root} or its subdirectories"
    )

def import_module_from_path(file_path: str, module_name: str):
    """
    Import a module from a file path using importlib and register it in
    ``sys.modules`` under ``module_name``.

    The module is registered before its code runs, following the importlib
    recipe, so imports of ``module_name`` made during execution resolve to
    it. If execution fails, the ``sys.modules`` entry is rolled back to its
    previous value (or removed), so no half-initialised module is left behind.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no spec/loader can be created for ``file_path``, or
        if the module's code raises. The original error is chained as
        ``__cause__``.
    """
    previous = sys.modules.get(module_name)
    registered = False
    try:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        # Either value can be None, e.g. for a suffix that no loader handles.
        if spec is None or spec.loader is None:
            raise ImportError(f"Failed to create spec for module: {file_path}")

        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        registered = True
        spec.loader.exec_module(module)
        return module

    except Exception as e:
        if registered:
            # Undo the registration so a retry doesn't pick up a broken module.
            if previous is None:
                sys.modules.pop(module_name, None)
            else:
                sys.modules[module_name] = previous
        print(f"[AppSentinels] Error importing module {module_name} from {file_path}")
        raise ImportError(f"Failed to import module: {str(e)}") from e

# Resolved user handlers, keyed by (handler path, app root). Lambda may reuse
# this process for warm invocations. Caching means the user's module is
# imported, and its module-level initialisation run, once per process rather
# than on every invocation.
_handler_cache = {}


def _resolve_handler(original_handler, app_root):
    """Import the user's module on first use and return its handler function."""
    cache_key = (original_handler, app_root)
    if cache_key in _handler_cache:
        return _handler_cache[cache_key]

    module_path, handler_name = parse_handler_path(original_handler)
    file_path, module_name = find_module_file(module_path, app_root)
    module = import_module_from_path(file_path, module_name)

    if not hasattr(module, handler_name):
        raise AttributeError(
            f"Module '{module_name}' has no function named '{handler_name}'. "
            f"Available attributes: {dir(module)}"
        )

    handler_func = getattr(module, handler_name)
    _handler_cache[cache_key] = handler_func
    return handler_func


def _report(event, context, response, latency):
    """Build the event log for one invocation and send it (best-effort)."""
    try:
        # Shallow-copy dict payloads. The size check may blank 'body', and
        # `response` is the object returned to Lambda, so it must stay intact.
        event_log = {
            'request': dict(event) if isinstance(event, dict) else event,
            'request_id': context.aws_request_id,
            # TODO: Add an option to compress the payload here.
            'response': dict(response) if isinstance(response, dict) else response,
            'latency': latency,
            'requestBodyTruncated': False,
            'responseBodyTruncated': False,
        }
        check_body_size(event_log)
        send_event_log(event_log)
    except Exception:
        # Telemetry must never fail an invocation whose handler succeeded.
        print('[AppSentinels] Failed to report event log')
        traceback.print_exc()


def wrapper(event, context):
    """
    Lambda entry point that proxies to the user's configured handler.

    The wrapper works in three steps:

    1. It resolves the handler named by the ``_HANDLER`` environment variable
       (a root-level or nested module under the current working directory).
       The result is cached for later invocations in this process.
    2. It calls the handler with the original ``event`` and ``context``.
    3. It reports request, response and handler latency (in milliseconds) to
       the AppSentinels collector.

    Handler-resolution errors and exceptions raised by the user's handler are
    printed with a traceback and then re-raised unchanged. Telemetry failures
    are printed but never change the returned response.
    """
    try:
        original_handler = os.environ['_HANDLER']
        app_root = os.getcwd()
        if app_root not in sys.path:
            sys.path.append(app_root)
        original_handler_func = _resolve_handler(original_handler, app_root)
    except Exception:
        traceback.print_exc()
        raise

    start_time = time.perf_counter()
    try:
        response = original_handler_func(event, context)
    except Exception:
        print('[AppSentinels] Exception in original functional handler')
        traceback.print_exc()
        raise
    # time.perf_counter() returns fractional seconds; convert to milliseconds.
    latency = (time.perf_counter() - start_time) * 1000

    _report(event, context, response, latency)
    return response

if __name__ == "__main__":
    # Read the Lambda Runtime API endpoint from the environment.
    runtime_api_address = os.environ['AWS_LAMBDA_RUNTIME_API']
    app_root = os.getcwd()

    # Run the bootstrap with our wrapper. Pass the computed app_root variable
    # (previously the literal string "app_root" was passed by mistake).
    # NOTE(review): __name__ is "__main__" here, so the handler string is
    # "__main__.wrapper". This presumably resolves to this running script via
    # sys.modules; confirm against the awslambdaric version in use.
    bootstrap.run(app_root, f"{__name__}.wrapper", runtime_api_address)