import asyncio
import datetime
import pwd
import re
import subprocess
import urllib.request
import os

from logging import getLogger
from urllib.error import URLError
from pathlib import Path

from defence360agent.utils import atomic_rewrite

# Module-level logger used by every helper in this file.
logger = getLogger(__name__)

# HTTPS URL the analyst (assisted-cleanup) public key is downloaded from
# by install_pub_key().
ANALYST_PUB_KEY_URL = (
    "https://repo.imunify360.cloudlinux.com/defense360/assisted-cleanup.pub"
)
# Regex fragment used to find the support key's line inside an
# authorized_keys file (presumably the key's trailing comment field).
KEY_PATTERN = r"clsupport@sshbox\.cloudlinux\.com"
# Main sshd configuration file and its drop-in directory; both are scanned
# for "Port" directives by get_ssh_port().
SSH_CONFIG_PATH = Path("/etc/ssh/sshd_config")
SSH_CONFIG_DIR = Path("/etc/ssh/sshd_config.d")

# \Z (not $) — $ would accept a trailing newline.
_USERNAME_RE = re.compile(r"^[a-z_][a-z0-9_-]{0,31}\Z")


def _resolve_authorized_keys(username: str) -> Path:
    """Return the ``authorized_keys`` path belonging to *username*.

    The home directory is taken from the passwd database
    (``pwd.getpwnam``) instead of being glued onto ``/home/``, which blocks
    path traversal through a crafted user name.  ``root`` is mapped straight
    to ``/root/.ssh/authorized_keys`` without a passwd lookup.

    Raises:
        ValueError: *username* is not a string accepted by ``_USERNAME_RE``,
            is unknown to the passwd database, or has an empty or relative
            home directory.
    """
    name_is_valid = isinstance(username, str) and _USERNAME_RE.match(username)
    if not name_is_valid:
        raise ValueError("invalid username: %r" % (username,))

    if username == "root":
        return Path("/root/.ssh/authorized_keys")

    try:
        passwd_entry = pwd.getpwnam(username)
    except KeyError as exc:
        raise ValueError("no such user: %r" % (username,)) from exc

    home_dir = passwd_entry.pw_dir
    # pw_dir is normally absolute, but panel-driven user creation can leave
    # it empty or relative; refusing beats writing somewhere under the CWD.
    if home_dir and os.path.isabs(home_dir):
        return Path(os.path.join(home_dir, ".ssh", "authorized_keys"))
    raise ValueError(
        "non-absolute home directory for %r: %r" % (username, home_dir)
    )


# The support pub key is shared across every Imunify install, so a leaked
# private counterpart would grant root on the whole fleet. Bound the blast
# radius via restrict + expiry-time options on the authorized_keys line.
# Key lifetime in days, used when the environment does not override it.
DEFAULT_KEY_TTL_DAYS = 7
# Name of the environment variable that may override the lifetime with a
# positive integer number of days (read by _key_ttl_days()).
KEY_TTL_ENV_VAR = "IMUNIFY_ASSISTED_CLEANUP_KEY_TTL_DAYS"
# Options prefixed to every installed key line. Per sshd(8), "restrict"
# disables forwarding, PTY allocation and ~/.ssh/rc execution, and "pty"
# then re-enables terminal allocation only.
KEY_OPTIONS_BASE = "restrict,pty"


# sshd_config(5): keywords are case-insensitive, and the argument follows
# either whitespace or optional whitespace plus a single "="; it may be
# wrapped in double quotes. The look-ahead rejects values such as "22abc".
_PORT_DIRECTIVE_RE = re.compile(
    r'^\s*port(?:\s*=\s*|\s+)"?(\d{1,5})"?(?=\s|#|$)',
    re.IGNORECASE,
)


async def get_ssh_port():
    """Detect the SSH port from sshd_config and its drop-in overrides.

    The candidate list is the main config followed by the drop-in
    directory's ``*.conf`` files in sorted order.  It is walked backwards
    (last override first) and the first valid ``Port`` directive wins.
    Commented-out lines, malformed values and ports outside 1-65535 are
    skipped.

    Returns:
        int: the detected port, or 22 when no usable directive exists or
        the configuration cannot be inspected.

    This is best-effort: ordinary errors are logged and the default is
    returned.  Unlike a ``return`` inside ``finally``, task cancellation
    and other ``BaseException``s are left to propagate.
    """
    default_port = 22
    try:
        config_files = [SSH_CONFIG_PATH]
        if SSH_CONFIG_DIR.exists():
            config_files.extend(sorted(SSH_CONFIG_DIR.glob("*.conf")))

        for config_file in reversed(config_files):
            try:
                # errors="replace": a stray non-UTF-8 byte in a comment must
                # not hide the Port directive in the rest of the file.
                lines = config_file.read_text(errors="replace").splitlines()
            except OSError as e:
                logger.warning(f"Failed to read {config_file}: {e}")
                continue

            for line in lines:
                found = _PORT_DIRECTIVE_RE.match(line)
                if found is None:
                    continue
                port = int(found.group(1))
                if 1 <= port <= 65535:
                    # First match while searching backwards is the answer.
                    return port
    except Exception as e:
        logger.warning(f"Failed to get SSH port: {e}")
    return default_port


async def check_ssh_connection(port=22):
    """Test if port is actually an SSH port by checking the server banner.

    Connects to ``127.0.0.1:port``, reads one line and checks that it
    starts with ``SSH-1.`` or ``SSH-2.``.

    Args:
        port: TCP port on the loopback interface to probe.

    Returns:
        bool: True when an SSH banner was received; False on connection
        failure, timeout, a non-SSH banner or any unexpected error.
    """
    timeout = 5.0

    try:
        # Bound the connect as well as the read: a filtered port would
        # otherwise stall for the kernel's full TCP connect timeout.
        reader, writer = await asyncio.wait_for(
            asyncio.open_connection("127.0.0.1", port), timeout=timeout
        )
    except asyncio.TimeoutError:
        # Listed before OSError: on Python 3.11+ TimeoutError is an OSError.
        logger.warning(f"Timeout connecting to port {port}")
        return False
    except OSError as e:
        logger.warning(f"Failed to connect to port {port}: {e}")
        return False
    except Exception as e:
        logger.warning(f"Unexpected error checking SSH port {port}: {e}")
        return False

    try:
        raw_banner = await asyncio.wait_for(
            reader.readline(), timeout=timeout
        )
    except asyncio.TimeoutError:
        logger.warning(f"Timeout waiting for SSH banner on port {port}")
        return False
    except OSError as e:
        logger.warning(f"Failed to connect to port {port}: {e}")
        return False
    except Exception as e:
        logger.warning(f"Unexpected error checking SSH port {port}: {e}")
        return False
    finally:
        writer.close()
        try:
            await writer.wait_closed()
        except Exception as e:
            # The peer may reset the connection because we hang up without
            # completing the handshake. That must not override the verdict
            # already reached from the banner.
            logger.debug("Ignoring error while closing probe socket: %s", e)

    banner = raw_banner.decode("utf-8", errors="ignore").strip()
    if re.match(r"^SSH-[12]\.", banner):
        logger.info(f"Port {port} is confirmed as SSH (banner: {banner})")
        return True

    logger.warning(f"Port {port} is open but not SSH (got: {banner})")
    return False


def _key_ttl_days() -> int:
    """Return the assisted-cleanup key lifetime in days.

    The value comes from the environment variable named by
    ``KEY_TTL_ENV_VAR``.  When it is unset, not an integer, zero or
    negative, ``DEFAULT_KEY_TTL_DAYS`` is returned instead.
    """
    configured = os.environ.get(KEY_TTL_ENV_VAR, "")
    try:
        days = int(configured)
    except (TypeError, ValueError):
        return DEFAULT_KEY_TTL_DAYS
    return days if days > 0 else DEFAULT_KEY_TTL_DAYS


def _expiry_timestamp(now: "datetime.datetime | None" = None) -> str:
    """Return the ``expiry-time`` value: *now* plus the key TTL.

    Args:
        now: reference instant; defaults to the current UTC time.  A naive
            datetime is interpreted as system local time.

    Returns:
        str: local wall-clock time formatted as ``YYYYMMDDHHMM``.
    """
    # Bare timestamp (no Z): Z requires OpenSSH >= 9.1; without it sshd
    # parses as local time per authorized_keys(5), so convert before format.
    if now is None:
        now = datetime.datetime.now(datetime.timezone.utc)
    # Add the TTL first and convert to local time afterwards. Converting
    # first pins today's UTC offset onto the result, so the stamp would be
    # an hour off whenever a DST change falls inside the TTL window.
    expiry = now + datetime.timedelta(days=_key_ttl_days())
    return expiry.astimezone().strftime("%Y%m%d%H%M")


# Extracts (major, minor) from a version banner such as "OpenSSH_8.7p1, ...".
_OPENSSH_VERSION_RE = re.compile(r"OpenSSH_(\d+)\.(\d+)")


async def _sshd_supports_expiry_time() -> bool:
    """Report whether the local OpenSSH understands ``expiry-time``.

    Runs ``ssh -V`` and compares the reported version against 7.7.

    Returns:
        bool: True for OpenSSH >= 7.7; False for older versions, an
        unparsable banner, a missing binary or a probe timeout.

    NOTE(review): this inspects the ``ssh`` client binary and presumably
    relies on it coming from the same OpenSSH build as sshd — confirm.
    """
    # expiry-time keyword exists since OpenSSH 7.7; older sshd (CL7) rejects
    # the whole line. Probe failure -> False so we fall back to restrict,pty.
    proc = None
    try:
        proc = await asyncio.create_subprocess_exec(
            "ssh",
            "-V",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=5)
    except (OSError, asyncio.TimeoutError) as e:
        if proc is not None and proc.returncode is None:
            # wait_for() only cancels communicate(); the child keeps
            # running. Kill and reap it so a hung probe does not leak a
            # process and its pipes.
            try:
                proc.kill()
                await proc.wait()
            except OSError:
                pass  # already gone
        logger.warning("ssh -V probe failed: %s", e)
        return False

    # ssh prints its version on stderr; stdout is only a fallback.
    output = (stderr or b"").decode("utf-8", errors="ignore") or (
        stdout or b""
    ).decode("utf-8", errors="ignore")
    match = _OPENSSH_VERSION_RE.search(output)
    if not match:
        logger.warning(
            "ssh -V did not match OpenSSH version pattern: %r", output[:200]
        )
        return False
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (7, 7)


def build_authorized_key_line(pub_key: str, *, supports_expiry: bool) -> str:
    """Compose the authorized_keys line for the support key.

    Args:
        pub_key: the public key text; surrounding whitespace is removed.
        supports_expiry: when true, an ``expiry-time="..."`` option built
            from ``_expiry_timestamp()`` is appended to the base options.

    Returns:
        str: ``"<options> <key>"`` without a trailing newline.
    """
    options = KEY_OPTIONS_BASE
    if supports_expiry:
        options = f'{options},expiry-time="{_expiry_timestamp()}"'
    return " ".join((options, pub_key.strip()))


def _target_uid_gid(username: str):
    """Look up the numeric owner for *username*'s authorized_keys file.

    Returns:
        tuple: ``(uid, gid)`` from the passwd database, or ``(None, None)``
        for ``root`` and for users the database does not know.  Callers
        hand the pair to ``atomic_rewrite``; the ``None`` pair is meant to
        make it skip its chown step and keep the file's current ownership.
    """
    keep_ownership = (None, None)
    if username == "root":
        return keep_ownership

    try:
        entry = pwd.getpwnam(username)
    except KeyError:
        logger.warning(
            "user %r not found; leaving authorized_keys ownership untouched",
            username,
        )
        return keep_ownership
    else:
        return entry.pw_uid, entry.pw_gid


def _download_analyst_pub_key(timeout: float = 30.0) -> str:
    """Fetch the analyst public key; return it without surrounding whitespace."""
    # The context manager closes the HTTP response even if read/decode
    # fails; the timeout keeps a stalled server from blocking forever.
    with urllib.request.urlopen(
        ANALYST_PUB_KEY_URL, timeout=timeout
    ) as response:
        return response.read().decode().strip()


def _chown_to_user(path: Path, username: str) -> None:
    """Hand *path* over to *username*; no-op for root or unknown users."""
    uid, gid = _target_uid_gid(username)
    if uid is None or gid is None:
        return
    # Numeric uid/gid from passwd: the user's primary group need not be
    # named after the user. follow_symlinks=False so a symlink planted in
    # the user's home cannot redirect the ownership change to another file.
    os.chown(str(path), uid, gid, follow_symlinks=False)


async def install_pub_key(username="root"):
    """Install the analyst public key into *username*'s authorized_keys.

    Downloads the key, makes sure ``~/.ssh`` (mode 0700) and
    ``authorized_keys`` (mode 0600) exist and belong to the user, then
    rewrites the file so that it holds exactly one guarded copy of the key.

    Args:
        username: account that receives the key (default ``"root"``).

    Returns:
        bool: True when the key line was written; False on any failure
        (invalid user, not running as root, download or filesystem error).
        Failures are logged, not raised.
    """
    # Idempotent: re-running rotates the expiry and replaces any legacy
    # (unguarded or older guarded) copy of the same key.
    try:
        try:
            auth_keys_path = _resolve_authorized_keys(username)
        except ValueError as e:
            logger.error("install_pub_key: %s", e)
            return False

        # If not running as root, fail
        if os.geteuid() != 0:
            logger.error("Function must be run as root")
            return False

        # Download the public key in a worker thread: urlopen blocks, and
        # calling it directly would stall the whole event loop.
        try:
            loop = asyncio.get_running_loop()
            pub_key = await loop.run_in_executor(
                None, _download_analyst_pub_key
            )
        except (URLError, OSError, ValueError) as e:
            # OSError covers socket timeouts; ValueError covers a body that
            # is not valid UTF-8.
            logger.error(f"Failed to download public key: {e}")
            return False

        if not pub_key:
            # An empty body would yield an options-only, invalid key line.
            logger.error("Downloaded public key is empty")
            return False

        # A genuine key is single-line; an embedded newline would split into
        # a second, option-less authorized_keys entry that bypasses restrict.
        if "\n" in pub_key or "\r" in pub_key:
            logger.error("Downloaded public key spans multiple lines")
            return False

        if not re.search(KEY_PATTERN, pub_key):
            # Rotation and removal locate the key through KEY_PATTERN.
            logger.warning(
                "Downloaded public key does not match the expected marker;"
                " later rotation/removal will not find it"
            )

        # Check if the authorized_keys directory exists, create if not
        auth_keys_dir = auth_keys_path.parent
        if not auth_keys_dir.exists():
            try:
                auth_keys_dir.mkdir(mode=0o700, parents=True, exist_ok=True)
                # Set proper ownership for the .ssh directory
                _chown_to_user(auth_keys_dir, username)
            except Exception as e:
                logger.error(
                    f"Failed to create directory {auth_keys_dir}: {e}"
                )
                return False

        # Check if the authorized_keys file exists, create if not
        if not auth_keys_path.exists():
            try:
                auth_keys_path.touch(mode=0o600)
                # Set proper ownership for the authorized_keys file
                _chown_to_user(auth_keys_path, username)
            except Exception as e:
                logger.error(f"Failed to create file {auth_keys_path}: {e}")
                return False

        try:
            guarded_line = build_authorized_key_line(
                pub_key,
                supports_expiry=await _sshd_supports_expiry_time(),
            )

            # Read existing content; strip any prior copy of the support
            # key (legacy unguarded or older guarded line) so re-running
            # rotates options + expiry instead of stacking duplicates.
            existing = auth_keys_path.read_text()
            new_content = re.sub(
                r".*" + KEY_PATTERN + r".*\n?",
                "",
                existing,
            )
            if new_content and not new_content.endswith("\n"):
                new_content += "\n"
            new_content += guarded_line + "\n"

            uid, gid = _target_uid_gid(username)
            atomic_rewrite(
                auth_keys_path,
                new_content,
                backup=False,
                uid=uid,
                gid=gid,
            )
            # Log only the options part of the line, never the key itself.
            logger.info(
                "Installed assisted-cleanup key for user %s (%s)",
                username,
                guarded_line.split(" ", 1)[0],
            )
            return True
        except IOError as e:
            logger.error(f"Failed to write to {auth_keys_path}: {e}")
            return False
    except Exception as e:
        logger.error(f"Failed to install public key: {e}")
        return False


def remove_pub_key(username="root") -> bool:
    """Strip the analyst public key from a user's authorized_keys file.

    Every line matching ``KEY_PATTERN`` is dropped and the file is written
    back through ``atomic_rewrite`` with ``backup=True``.

    Args:
        username: account whose authorized_keys file is edited.

    Returns:
        bool: True if the key was found and the file rewritten; False when
        the user is invalid, the file is missing or unreadable, the key is
        not present, or writing fails.  Failures are logged, not raised.
    """
    try:
        try:
            keys_file = _resolve_authorized_keys(username)
        except ValueError as exc:
            logger.error("remove_pub_key: %s", exc)
            return False

        # Nothing to do without a file.
        if not keys_file.exists():
            logger.warning(f"authorized_keys file not found at {keys_file}")
            return False

        try:
            current = keys_file.read_text()
        except IOError as exc:
            logger.error(f"Failed to read {keys_file}: {exc}")
            return False

        # Bail out early when the key is not installed.
        if re.search(KEY_PATTERN, current) is None:
            logger.info(f"Analyst public key not found in {keys_file}")
            return False

        # Drop each whole line carrying the key, newline included.
        key_line_pattern = r".*" + KEY_PATTERN + r".*\n?"
        remaining = re.sub(key_line_pattern, "", current)

        if remaining.strip() == "":
            logger.info(f"File {keys_file} will be empty after removal")

        try:
            owner_uid, owner_gid = _target_uid_gid(username)
            atomic_rewrite(
                keys_file,
                remaining,
                backup=True,
                uid=owner_uid,
                gid=owner_gid,
            )
            logger.info(
                f"Successfully removed analyst public key from {keys_file}"
            )
            return True
        except IOError as exc:
            logger.error(f"Failed to write to {keys_file}: {exc}")
            return False

    except Exception as exc:
        logger.error(f"Failed to remove public key: {exc}")
        return False
