import base64
import json
import logging
import math
from typing import Sequence

import requests
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult
from opentelemetry.trace import SpanKind, StatusCode

logger = logging.getLogger(__name__)

# OTLP ``Span.SpanKind`` numbers the concrete kinds 1..5 in exactly this
# order (0 is SPAN_KIND_UNSPECIFIED and therefore has no entry here).
_KIND_MAP = {
    sdk_kind: otlp_number
    for otlp_number, sdk_kind in enumerate(
        (
            SpanKind.INTERNAL,
            SpanKind.SERVER,
            SpanKind.CLIENT,
            SpanKind.PRODUCER,
            SpanKind.CONSUMER,
        ),
        start=1,
    )
}


def _encode_id(value: int, length: int) -> str:
    return base64.b64encode(value.to_bytes(length, "big")).decode()


def _attr_value(v):
    if isinstance(v, bool):
        return {"boolValue": v}
    if isinstance(v, int):
        return {"intValue": str(v)}
    if isinstance(v, float):
        return {"doubleValue": v}
    if isinstance(v, (list, tuple)):
        return {"arrayValue": {"values": [_attr_value(i) for i in v]}}
    return {"stringValue": str(v)}


def _fmt_attrs(attrs):
    if not attrs:
        return []
    return [{"key": k, "value": _attr_value(v)} for k, v in attrs.items()]


# OTLP ``Status.code`` enum values: STATUS_CODE_UNSET=0, STATUS_CODE_OK=1,
# STATUS_CODE_ERROR=2.
_STATUS_CODE_MAP = {
    StatusCode.UNSET: 0,
    StatusCode.OK: 1,
    StatusCode.ERROR: 2,
}


def _status_to_dict(status) -> dict:
    """Build the OTLP ``Status`` object, adding ``message`` when a description is set."""
    out = {"code": _STATUS_CODE_MAP.get(status.status_code, 0)}
    if status.description:
        out["message"] = status.description
    return out


def _span_to_dict(span: ReadableSpan) -> dict:
    """Build the OTLP/JSON ``Span`` object for one span."""
    out = {
        "traceId": _encode_id(span.context.trace_id, 16),
        "spanId": _encode_id(span.context.span_id, 8),
        "name": span.name,
        "kind": _KIND_MAP.get(span.kind, 0),
        "startTimeUnixNano": str(span.start_time),
        "endTimeUnixNano": str(span.end_time),
        "attributes": _fmt_attrs(span.attributes),
        "events": [
            {
                "name": e.name,
                "timeUnixNano": str(e.timestamp),
                "attributes": _fmt_attrs(e.attributes),
            }
            for e in span.events
        ],
        "links": [
            {
                "traceId": _encode_id(lk.context.trace_id, 16),
                "spanId": _encode_id(lk.context.span_id, 8),
                "attributes": _fmt_attrs(lk.attributes),
            }
            for lk in span.links
        ],
        "status": _status_to_dict(span.status),
    }
    # Root spans have no parent context, so the field is left out for them.
    if span.parent is not None:
        out["parentSpanId"] = _encode_id(span.parent.span_id, 8)
    return out


def _serialize(spans: Sequence[ReadableSpan]) -> dict:
    """Build an OTLP/JSON trace export request body from finished spans.

    Spans are grouped by resource *object identity* (spans carrying the same
    ``Resource`` instance share one ``resourceSpans`` entry). Within a
    resource they are grouped by instrumentation scope ``(name, version)``.
    Insertion order is preserved at every level.

    Args:
        spans: The spans to encode.

    Returns:
        A dict of the form ``{"resourceSpans": [...]}`` ready for
        ``json.dumps``.
    """
    # id(resource) -> {"resource": Resource | None, "scopes": {(name, version): [span dicts]}}
    resource_map: dict = {}
    for span in spans:
        entry = resource_map.setdefault(
            id(span.resource), {"resource": span.resource, "scopes": {}}
        )
        scope = span.instrumentation_scope
        scope_key = (scope.name, scope.version) if scope is not None else ("", "")
        entry["scopes"].setdefault(scope_key, []).append(_span_to_dict(span))

    resource_spans = []
    for entry in resource_map.values():
        resource = entry["resource"]
        resource_attrs = dict(resource.attributes) if resource is not None else {}
        resource_spans.append({
            "resource": {"attributes": _fmt_attrs(resource_attrs)},
            "scopeSpans": [
                {
                    "scope": {"name": name or "", "version": version or ""},
                    "spans": scope_spans,
                }
                for (name, version), scope_spans in entry["scopes"].items()
            ],
        })

    return {"resourceSpans": resource_spans}


class OTLPHTTPJsonExporter(SpanExporter):
    """Export finished spans to an OTLP/HTTP endpoint using the JSON encoding.

    Args:
        endpoint: URL the JSON payload is POSTed to, used verbatim.
        headers: Extra HTTP headers merged over the default
            ``Content-Type: application/json``.
        timeout: Per-request timeout in seconds, passed to ``requests``.
    """

    def __init__(self, endpoint: str, headers: dict | None = None, timeout: int = 10):
        self.endpoint = endpoint
        self.headers = {"Content-Type": "application/json"}
        if headers:
            self.headers.update(headers)
        self.timeout = timeout
        # One Session reuses pooled connections across export batches instead
        # of opening a new connection per POST.
        self._session = requests.Session()
        self._shutdown = False

    def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
        """Serialize ``spans`` and POST them to ``self.endpoint``.

        Returns:
            ``SpanExportResult.SUCCESS`` on a 2xx/3xx response or when there is
            nothing to send. ``SpanExportResult.FAILURE`` after ``shutdown()``,
            on serialization errors, connection errors, HTTP error statuses,
            or any other exception. Failures are logged, not raised.
        """
        if self._shutdown:
            logger.warning("Exporter already shut down; dropping %d span(s)", len(spans))
            return SpanExportResult.FAILURE
        if not spans:
            return SpanExportResult.SUCCESS

        try:
            # Encode once; the same body is used for the debug log and the request.
            body = json.dumps(_serialize(spans))
        except Exception as e:
            logger.exception("Failed to serialize %d span(s): %s", len(spans), e)
            return SpanExportResult.FAILURE

        logger.debug("Exporting %d span(s) to %s: %s", len(spans), self.endpoint, body)
        try:
            resp = self._session.post(
                self.endpoint,
                data=body,
                headers=self.headers,
                timeout=self.timeout,
            )
            resp.raise_for_status()
        except requests.exceptions.ConnectionError as e:
            logger.error("Connection error exporting spans to %s: %s", self.endpoint, e)
            return SpanExportResult.FAILURE
        except requests.exceptions.HTTPError as e:
            # ``Response.__bool__`` is False for 4xx/5xx, so a truthiness test
            # would always hide the body; compare against None instead.
            detail = e.response.text if e.response is not None else ""
            logger.error("HTTP error exporting spans: %s — response: %s", e, detail)
            return SpanExportResult.FAILURE
        except Exception as e:
            logger.exception("Unexpected error exporting spans: %s", e)
            return SpanExportResult.FAILURE

        # Per-batch success is routine; keep it out of INFO-level logs.
        logger.debug("Successfully exported %d span(s) — HTTP %s", len(spans), resp.status_code)
        return SpanExportResult.SUCCESS

    def shutdown(self) -> None:
        """Stop accepting exports and close the HTTP session.

        Later ``export`` calls return ``FAILURE``. Calling this more than once
        is a no-op.
        """
        if self._shutdown:
            return
        self._shutdown = True
        self._session.close()
