import logging
import os
from collections import defaultdict

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from opentelemetry import trace
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from otlp_json_exporter import OTLPHTTPJsonExporter

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s %(name)s — %(message)s",
)

# Resource.create() with no explicit attributes lets the SDK fill in
# service.name from the OTEL_SERVICE_NAME environment variable.
resource = Resource.create()
provider = TracerProvider(resource=resource)

# Processor 1 — local visibility: each finished span is handed straight to a
# console exporter, so it shows up on the container's stdout.
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))

# Processor 2 — remote export: spans are batched and shipped to the
# OTLP/HTTP JSON endpoint. Indexing os.environ (rather than .get) means a
# missing OTEL_EXPORTER_OTLP_ENDPOINT fails loudly at import time.
provider.add_span_processor(
    BatchSpanProcessor(
        OTLPHTTPJsonExporter(
            endpoint=os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"],
        ),
    ),
)

# Install as the global provider used by trace.get_tracer()/get_current_span().
trace.set_tracer_provider(provider)

# Hypertrace default body capture limit is 128KB
_MAX_BODY_BYTES = 131_072


class BodyCaptureMiddleware(BaseHTTPMiddleware):
    """Record HTTP headers and (truncated) bodies on the current span.

    The middleware reads whatever span is current when ``dispatch`` runs
    (``trace.get_current_span()``). It is meant to run inside the server span
    started by the OpenTelemetry instrumentation; with no active span the
    attribute writes go to a non-recording span and are discarded.

    Recorded attributes:
      * ``http.request.header.<name>`` / ``http.response.header.<name>``:
        a string for a header seen once, a list of strings for a repeated one.
      * ``http.request.body`` / ``http.response.body``: the first
        ``_MAX_BODY_BYTES`` bytes, decoded as UTF-8 with ``errors="replace"``.

    The whole response body is buffered in memory before being re-sent, so a
    streamed response (e.g. server-sent events) reaches the client all at once.

    NOTE(review): every header is recorded verbatim, including credentials such
    as ``authorization`` and ``cookie``, and then exported to stdout and the
    OTLP endpoint — consider redacting those before export.
    """

    @staticmethod
    def _set_header_attributes(span: trace.Span, prefix: str, headers) -> None:
        """Record each header in *headers* on *span* as ``{prefix}.{name}``."""
        grouped = defaultdict(list)
        # headers.items() yields one pair per occurrence, so repeated headers
        # are grouped under a single (lower-cased) attribute name.
        for key, value in headers.items():
            grouped[key.lower()].append(value)
        for key, values in grouped.items():
            span.set_attribute(
                f"{prefix}.{key}", values if len(values) > 1 else values[0]
            )

    @staticmethod
    async def _replay(body: bytes):
        """Yield an already-buffered response body as a single chunk."""
        yield body

    async def dispatch(self, request: Request, call_next) -> Response:
        """Capture request/response data on the current span, then pass the response on."""
        span = trace.get_current_span()

        self._set_header_attributes(span, "http.request.header", request.headers)

        req_body = await request.body()
        if req_body:
            span.set_attribute(
                "http.request.body",
                req_body[:_MAX_BODY_BYTES].decode("utf-8", errors="replace"),
            )

        response = await call_next(request)

        # Drain the downstream body so it can be recorded; it is replayed below.
        resp_body = b"".join([chunk async for chunk in response.body_iterator])

        self._set_header_attributes(span, "http.response.header", response.headers)
        if resp_body:
            span.set_attribute(
                "http.response.body",
                resp_body[:_MAX_BODY_BYTES].decode("utf-8", errors="replace"),
            )

        # Re-attach the buffered body to the ORIGINAL response rather than
        # building a new Response from dict(response.headers): dict() keeps only
        # the first value of a repeated header, which silently dropped all but
        # one Set-Cookie header. Status, raw headers, and media type are kept
        # exactly as the endpoint produced them.
        response.body_iterator = self._replay(resp_body)
        return response


app = FastAPI()
# Order matters: BodyCaptureMiddleware is added BEFORE instrument_app. The
# intent is that instrument_app places the OpenTelemetry middleware at the
# outermost position, so OTel starts the server span first and
# BodyCaptureMiddleware runs inside it, where trace.get_current_span() returns
# that span rather than a no-op one.
# NOTE(review): how instrument_app positions its middleware has varied across
# opentelemetry-instrumentation-fastapi releases — confirm against the pinned
# version before changing this ordering.
app.add_middleware(BodyCaptureMiddleware)
FastAPIInstrumentor().instrument_app(app)

# Tracer for the manual "get_test"/"post_test" spans opened in the handlers.
tracer = trace.get_tracer(__name__)


@app.get("/testjson")
async def get_test(request: Request):
    """Return a greeting plus the request method and path.

    The response payload is assembled inside a manual ``get_test`` span so it
    appears as a child of the server span.
    """
    with tracer.start_as_current_span("get_test"):
        payload = dict(
            message="Hello from OpenTelemetry-instrumented FastAPI!",
            method=request.method,
            path=request.url.path,
        )
    return payload


@app.post("/testjson")
async def post_test(request: Request):
    """Echo the request method, path, and raw body inside a ``post_test`` span.

    The body is decoded as UTF-8 with ``errors="replace"``. A payload that is
    not valid UTF-8 is therefore echoed with U+FFFD replacement characters
    instead of raising ``UnicodeDecodeError`` out of the handler.
    """
    with tracer.start_as_current_span("post_test"):
        body = await request.body()
        return {
            "message": "Hello from OpenTelemetry-instrumented FastAPI!",
            "method": request.method,
            "path": request.url.path,
            # Lenient decoding, the same scheme BodyCaptureMiddleware uses for
            # span bodies; arbitrary client bytes must not crash the endpoint.
            "body": body.decode("utf-8", errors="replace"),
        }
