"""
Shared reader for the local feature flags file.

The file is written by:
- Go resident-agent FeatureFlags plugin (IM360 mode)
- Python FeatureFlagsSync plugin (AV mode)

Other subsystems (e.g. message_status_publisher) use this module
to check individual flag values at runtime.

Supported on-disk JSON shapes (accepted by the internal readers, ``is_enabled``
and ``get_params``):
- New shape ``{"flags": ["mqtt_tracking"], "params": {"flag": ["A", "B"]}}``
  (mirrors the sync API response; carries per-flag string-list params).
- Legacy object ``{"mqtt_tracking": true, ...}`` (still accepted).
- JSON array of enabled names ``["mqtt_tracking"]`` (still accepted).
- Legacy wrapper ``{"flags": ["mqtt_tracking", ...]}`` (still accepted).

Only the new shape carries params; for every legacy shape ``get_params``
returns an empty list.

The sync API checksum collapses to the legacy sorted-names array when no
params are present, so this agent and older agents agree on the bool-only
case. With params, the canonical form expands to ``{"flags": [...], "params":
{...}}`` with all keys and list members sorted.

The sync plugin also writes ``FLAGS_PLAIN_PATH`` (``/var/imunify360/feature_flags``):
plain text, one enabled flag name per line (sorted), for scripts.
"""

from __future__ import annotations

import hashlib
import json
import os
from typing import Any

# Primary on-disk flags file (JSON). Any of the shapes listed in the module
# docstring is accepted by the readers in this module.
FLAGS_PATH = "/var/imunify360/feature_flags.json"
# Plain list of enabled flag names (one per line), same order as sorted JSON array.
FLAGS_PLAIN_PATH = "/var/imunify360/feature_flags"

# Flag name whose params list drives MQTT message-status enrichment.
MQTT_TRACKED_METHODS_FLAG = "mqtt_tracked_methods"

# Module-level read cache. _read_state() rebuilds all four values when the
# mtime of FLAGS_PATH changes and resets them to empty when the file cannot
# be stat'ed.
# Flat name -> value mapping consulted by is_enabled().
_cached_flags: dict[str, Any] = {}
# Per-flag string-list params (new file shape only), consulted by get_params().
_cached_params: dict[str, list[str]] = {}
# Pre-built frozenset for the MQTT tracked-methods allow-list. Cached
# alongside the raw params dict so the hot path (every Reportable message
# in the_sink._call_unlocked) avoids re-allocating a fresh frozenset and
# the list copy that get_params() would do. Invalidated by the same
# file-mtime trigger that invalidates _cached_params.
_cached_mqtt_methods: frozenset[str] = frozenset()
# mtime of FLAGS_PATH recorded at the last parse; 0.0 means "never loaded"
# or "file could not be stat'ed on the last check".
_cached_mtime: float = 0.0


def _normalize_flags_from_file(raw: Any) -> dict[str, Any]:
    """Flatten any supported on-disk shape into a ``{name: value}`` dict.

    * ``{"flags": [...], ...}`` wrapper: the inner list is used.
    * List of names: every string entry maps to ``True`` (non-strings are
      skipped; duplicates collapse, first occurrence keeps its position).
    * Any other dict: treated as the legacy flat map and returned as-is
      (the same object, not a copy).
    * Anything else (``None``, numbers, strings): empty dict.
    """
    if isinstance(raw, dict):
        nested = raw.get("flags")
        if not isinstance(nested, list):
            # Legacy {name: value} map is already flat.
            return raw
        raw = nested
    if not isinstance(raw, list):
        return {}
    return dict.fromkeys((name for name in raw if isinstance(name, str)), True)


def _params_from_file(raw: Any) -> dict[str, list[str]]:
    """Pull the per-flag ``params`` mapping out of new-shape file content.

    Only ``{"flags": [...], "params": {name: [...]}}`` carries params; any
    other shape yields ``{}``. Entries whose name is not a string or whose
    value is not a list are dropped, non-string list members are filtered
    out, and a flag left with no string members is omitted entirely.
    """
    section = raw.get("params") if isinstance(raw, dict) else None
    if not isinstance(section, dict):
        return {}
    result: dict[str, list[str]] = {}
    for flag, candidates in section.items():
        if isinstance(flag, str) and isinstance(candidates, list):
            strings = [item for item in candidates if isinstance(item, str)]
            if strings:
                result[flag] = strings
    return result


def _read_state() -> tuple[dict[str, Any], dict[str, list[str]]]:
    """Return ``(flags, params)`` parsed from ``FLAGS_PATH``, cached by mtime.

    The file is stat'ed on every call but only re-read and re-parsed when its
    mtime differs from the one recorded at the last parse. Each refresh
    updates the module caches ``_cached_flags``, ``_cached_params``,
    ``_cached_mqtt_methods`` and ``_cached_mtime``.

    * File cannot be stat'ed (missing, permissions): every cache is reset to
      empty and ``_cached_mtime`` to ``0.0``.
    * File cannot be opened, is not valid UTF-8, or is not valid JSON: flags
      and params become empty, but the mtime is still recorded, so the bad
      file is not re-parsed until it changes again.

    Never raises for a bad or missing file, so hot-path callers such as
    :func:`is_enabled` cannot be broken by file contents.

    Returns the cached dicts themselves, not copies; callers should treat
    them as read-only.
    """
    global _cached_flags, _cached_params, _cached_mqtt_methods, _cached_mtime
    try:
        mtime = os.path.getmtime(FLAGS_PATH)
    except OSError:
        _cached_flags = {}
        _cached_params = {}
        _cached_mqtt_methods = frozenset()
        _cached_mtime = 0.0
        return _cached_flags, _cached_params

    if mtime == _cached_mtime:
        return _cached_flags, _cached_params

    try:
        # Explicit UTF-8 (same as sync_checksum_hex_from_flags_file) so the
        # result does not depend on the process locale. UnicodeDecodeError is
        # a ValueError, not a JSONDecodeError, so it must be listed separately
        # or it escapes into is_enabled().
        with open(FLAGS_PATH, encoding="utf-8") as f:
            raw = json.load(f)
    except (OSError, UnicodeDecodeError, json.JSONDecodeError):
        # Both helpers map None to an empty result.
        # NOTE(review): if a writer replaces the file non-atomically and the
        # partial read shares the final mtime, this empty state persists
        # until the next mtime change.
        raw = None
    _cached_flags = _normalize_flags_from_file(raw)
    _cached_params = _params_from_file(raw)
    _cached_mqtt_methods = frozenset(
        _cached_params.get(MQTT_TRACKED_METHODS_FLAG, ())
    )
    _cached_mtime = mtime
    return _cached_flags, _cached_params


def _read_flags() -> dict[str, Any]:
    """Return the (cached) flat flag mapping from :func:`_read_state`."""
    return _read_state()[0]


def _read_params() -> dict[str, list[str]]:
    """Return the (cached) per-flag params mapping from :func:`_read_state`."""
    return _read_state()[1]


def enabled_flag_names_sorted(flags: Any) -> list[str]:
    """List the truthy flag names in *flags*, sorted, for JSON and the sidecar.

    Accepts the same shapes as :func:`_normalize_flags_from_file` (array,
    flat map, ``{"flags": [...]}``) so checksums and sidecars match Go
    ``enabledNamesSortedForChecksum`` / :func:`is_enabled`.

    Raises:
        TypeError: if *flags* is neither a list nor a dict.
    """
    if isinstance(flags, (list, dict)):
        flat = _normalize_flags_from_file(flags)
        names = [name for name, value in flat.items() if value]
        names.sort()
        return names
    raise TypeError(
        f"flags must be list or dict, not {type(flags).__name__}"
    )


def canonical_sync_flag_list_bytes(names: list[str]) -> bytes:
    """Encode *names* as the sorted, 2-space-indented JSON array used for the
    sync MD5 when no params are present (matches correlation_api
    ``checksum_for_sync_flag_list``). The input list is not modified."""
    text = json.dumps(sorted(names), sort_keys=True, indent=2)
    return text.encode()


def canonical_sync_response_bytes(
    names: list[str], params: dict[str, list[str]]
) -> bytes:
    """Encode the full sync response shape for the MD5 checksum.

    Mirrors correlation_api ``checksum_for_sync_response``. With empty
    *params* the output is the legacy sorted-names array (identical to
    :func:`canonical_sync_flag_list_bytes`), so older agents keep matching.
    Otherwise the output is ``{"flags": [...], "params": {...}}`` with flag
    names, param keys and every param list sorted, 2-space indented.
    """
    if params:
        sorted_flags = sorted(names)
        sorted_params = {
            flag: sorted(values) for flag, values in sorted(params.items())
        }
        document = {"flags": sorted_flags, "params": sorted_params}
        return json.dumps(document, sort_keys=True, indent=2).encode()
    return canonical_sync_flag_list_bytes(names)


def sync_checksum_hex_from_flags_file(path: str) -> str:
    """MD5 hex of the canonical sync-response form for ``path``.

    Returns "" if the file is missing, unreadable, not valid UTF-8 JSON, or
    holds a JSON value that is not an array/object (e.g. ``null``, a number
    or a string). Otherwise computes the same MD5 the server returned, so a
    matching checksum lets the agent skip the response payload on the next
    sync.
    """
    try:
        with open(path, encoding="utf-8") as f:
            raw = json.load(f)
    except (OSError, UnicodeDecodeError, json.JSONDecodeError):
        return ""
    if not isinstance(raw, (list, dict)):
        # enabled_flag_names_sorted() raises TypeError for non-containers;
        # an invalid file must produce "" per the contract above.
        return ""
    names = enabled_flag_names_sorted(raw)
    params = _params_from_file(raw)
    payload = canonical_sync_response_bytes(names, params)
    return hashlib.md5(payload, usedforsecurity=False).hexdigest()


def legacy_feature_flags_map_bytes(names: list[str]) -> bytes:
    """Encode *names* as the legacy on-disk ``{flag: true, ...}`` JSON map.

    Non-string entries are ignored, duplicates collapse, and keys are
    emitted sorted with 2-space indentation.
    """
    unique_names = sorted({name for name in names if isinstance(name, str)})
    mapping = dict.fromkeys(unique_names, True)
    return json.dumps(mapping, sort_keys=True, indent=2).encode()


def sync_response_file_bytes(
    names: list[str], params: dict[str, list[str]]
) -> bytes:
    """Bytes persisted to ``FLAGS_PATH``: both flags and params, always.

    Unlike :func:`canonical_sync_response_bytes` this never collapses to the
    legacy array; the ``{"flags": [...], "params": {...}}`` object is written
    even when *params* is empty, so the file is self-describing and
    round-trips through ``sync_checksum_hex_from_flags_file``. Flag names,
    param keys and every param list are sorted; output is 2-space indented.
    """
    ordered_flags = sorted(names)
    ordered_params: dict[str, list[str]] = {}
    for flag in sorted(params):
        ordered_params[flag] = sorted(params[flag])
    document = {"flags": ordered_flags, "params": ordered_params}
    serialized = json.dumps(document, sort_keys=True, indent=2)
    return serialized.encode()


def plain_text_payload_for_enabled_flags(flags: Any) -> bytes:
    """Render the ``FLAGS_PLAIN_PATH`` body for *flags*.

    One enabled flag name per line, sorted, each line newline-terminated;
    an empty payload (``b""``) when nothing is enabled. Raises ``TypeError``
    (via :func:`enabled_flag_names_sorted`) for non-list/dict input.
    """
    enabled = enabled_flag_names_sorted(flags)
    return "".join(f"{name}\n" for name in enabled).encode()


def serialize_feature_flags_file_payload(flags: Any) -> bytes:
    """Encode a legacy flag map for ``FLAGS_PATH`` as sorted, indented JSON.

    Raises:
        TypeError: if *flags* is not a dict.
    """
    if not isinstance(flags, dict):
        raise TypeError(f"flags must be dict, not {type(flags).__name__}")
    encoded = json.dumps(flags, sort_keys=True, indent=2)
    return encoded.encode()


def is_enabled(flag_name: str, default: bool = False) -> bool:
    """Report whether *flag_name* is switched on in the local flags file.

    *default* (``False`` unless the caller opts in) is returned when the
    file is missing or unreadable, when the flag is absent, or when its
    stored value is JSON ``null``. Any other stored value is coerced with
    ``bool()``.
    """
    value = _read_flags().get(flag_name)
    return default if value is None else bool(value)


def get_params(flag_name: str) -> list[str]:
    """Return a fresh copy of *flag_name*'s string params from the flags file.

    The result is an empty list when the file is missing/unreadable, the
    flag has no params entry, or the file uses a legacy shape (only the new
    ``{"flags": [...], "params": {...}}`` shape carries params). Callers may
    mutate the returned list without affecting the cache.
    """
    params = _read_params()
    if flag_name in params:
        return list(params[flag_name])
    return []


def mqtt_tracked_methods() -> frozenset[str]:
    """Method names whose status events should be enriched for MQTT tracing.

    The set comes solely from the params list of the server-side
    ``mqtt_tracked_methods`` flag; nothing is hard-coded in the agent. Adding
    or removing a tracked type is therefore a server-side config change
    with no agent rollout.

    The frozenset is built by ``_read_state`` and rebuilt only when the
    flags file's mtime changes. While the mtime is unchanged, a call costs
    one ``os.stat`` and returns the very same frozenset object as the
    previous call. This matters on the hot path (every Reportable message in
    ``the_sink._call_unlocked``).
    """
    # Refresh must happen first: it may rebind the module-level cache.
    _read_state()
    return _cached_mqtt_methods
