#!/usr/bin/env python3
"""Check paren balance in #+begin_src lisp blocks of .org files.

Uses SBCL's actual reader for 100% accuracy.

Usage: check-parens <file.org> [<file.org> ...]
       check-parens -v <file.org>

Exit 0 if all blocks balanced and terminated, 1 otherwise.
"""

import os
import sys
import re
import subprocess
import tempfile


def check_file(path, verbose):
    """Check every lisp source block in one .org file.

    Args:
        path: Path of the .org file to read.
        verbose: When true, echo the body of each failing block below its
            error line.

    Returns:
        True when the file could be read and no block produced an error;
        False when the file could not be read or at least one block failed.
    """
    lines = read_lines(path)
    if lines is None:
        return False

    ok = True
    for start, body in extract_blocks(lines):
        if not body:
            continue

        # Run the reader exactly once per block: each call spawns an SBCL
        # process, and reusing the result guarantees that the message printed
        # is the one that was actually tested.
        error = is_reader_error(body)
        if not error:
            continue

        print(f"{path}: Block at line {start}: {error}")
        if verbose:
            for line in body:
                print(f"  | {line}")
        ok = False

    return ok


def read_lines(path):
    """Return the lines of *path* (decoded as UTF-8), or None on failure.

    Line terminators are kept. Any failure is reported on stderr, prefixed
    with the path, and signalled to the caller by a None return value.
    """
    try:
        with open(path, encoding="utf-8") as handle:
            contents = list(handle)
    except FileNotFoundError:
        message = f"{path}: file not found"
    except Exception as exc:
        message = f"{path}: error reading file — {exc}"
    else:
        return contents

    print(message, file=sys.stderr)
    return None


# Opening and closing markers of an org-mode source block. Both are matched
# case-insensitively against a line whose leading whitespace has been removed.
LISP_BEGIN = re.compile(r"#\+begin_src\s+lisp\b", re.IGNORECASE)
END_SRC = re.compile(r"#\+end_src\b", re.IGNORECASE)


def extract_blocks(lines):
    """Collect the lisp source blocks found in *lines*.

    Returns a list of ``(start, body)`` tuples, where ``start`` is the
    1-based line number of the ``#+begin_src lisp`` marker and ``body`` is
    the list of lines between the markers with their trailing newline
    removed. A block still open when the input ends is returned as well,
    with whatever body was collected so far.
    """
    found = []
    current = None  # (start line, body) while inside a block, else None

    for lineno, raw in enumerate(lines, start=1):
        candidate = raw.lstrip()

        if current is None:
            # Outside a block only an opening marker is of interest.
            if LISP_BEGIN.match(candidate):
                current = (lineno, [])
            continue

        if END_SRC.match(candidate):
            found.append(current)
            current = None
            continue

        current[1].append(raw.rstrip("\n"))

    if current is not None:
        found.append(current)

    return found


SBCL = "/usr/bin/sbcl"
CHECKER_LISP = "/tmp/check-parens-reader.lisp"

# Lisp helper loaded into SBCL by is_reader_error. CP-CHECK::CHECK reads every
# top-level form of a file and returns :OK, :PACKAGE-ERROR, :OTHER-ERROR, or a
# string starting with "READER-ERROR: " / "EOF: ".
#
# * End of input is detected with a fresh sentinel object rather than NIL, so
#   a block containing `()` or `nil` as a top-level form no longer ends the
#   scan early and hides errors in the forms after it.
# * READ-FILE returns only the characters actually read; FILE-LENGTH can be
#   larger than the character count for multi-byte UTF-8 content.
#
# NOTE(review): *READ-EVAL* is not rebound here, so `#.` forms in a checked
# block are evaluated at read time. Only run this tool on trusted files.
CHECKER_SRC = r"""(in-package :cl-user)
(defpackage :cp-check (:use :cl))
(in-package :cp-check)
(defun read-file (path)
  (with-open-file (s path :external-format :utf-8)
    (let* ((buf (make-string (file-length s)))
           (n (read-sequence buf s)))
      (subseq buf 0 n))))
(defun check (path)
  (handler-case
      (let ((str (read-file path))
            (eof (list :eof))
            (pos 0))
        (loop
          (multiple-value-bind (form new-pos)
              (read-from-string str nil eof :start pos)
            (when (eq form eof)
              (return :OK))
            (setf pos new-pos))))
    (sb-int:simple-reader-package-error (c)
      (declare (ignore c))
      :PACKAGE-ERROR)
    (sb-int:simple-reader-error (c)
      (format nil "READER-ERROR: ~a" c))
    (end-of-file (c)
      (format nil "EOF: ~a" c))
    (error (c)
      (declare (ignore c))
      :OTHER-ERROR)))
"""


def _install_checker(path=CHECKER_LISP, source=CHECKER_SRC):
    """Make sure the file at *path* contains exactly *source*.

    An existing file is left alone only when its content already matches, so
    a stale helper from an older version of this script (or a file put there
    by someone else) is never loaded as-is. Otherwise the new content is
    written to a private temporary file in the same directory and moved into
    place with os.replace, so a concurrently running SBCL never sees a
    half-written helper.

    Raises:
        OSError: if the helper has to be (re)written and that fails.
    """
    try:
        with open(path, encoding="utf-8") as f:
            if f.read() == source:
                return
    except (OSError, UnicodeDecodeError):
        pass  # missing or unreadable: fall through and (re)write it

    fd, staging = tempfile.mkstemp(
        prefix=".check-parens-", suffix=".lisp",
        dir=os.path.dirname(path) or ".",
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(source)
        # mkstemp creates the file owner-only; keep the helper world-readable
        # as it was when created with a plain open().
        os.chmod(staging, 0o644)
        os.replace(staging, path)
    except BaseException:
        try:
            os.unlink(staging)
        except OSError:
            pass
        raise


# Import-time setup: write (or refresh) the checker lisp module.
_install_checker()


def parse_result(output):
    """Translate the checker's stdout into an error description.

    Args:
        output: Text SBCL printed on stdout.

    Returns:
        None when the output is empty, reports ``:OK`` or ``:PACKAGE-ERROR``,
        or contains no recognised result token; otherwise a short
        human-readable description of the problem.
    """
    output = output.strip()
    if not output:
        return None

    lines = [raw.strip() for raw in output.split("\n")]

    # SBCL may print other lines before the result, so scan from the end and
    # stop at the first line carrying one of our result tokens. Lines that
    # look like comments (";") or unreadable-object output ("#<") are skipped.
    for idx, line in reversed(list(enumerate(lines))):
        if line.startswith((";", "#<")):
            continue
        if line.endswith(":OK"):
            return None
        if line.startswith(":PACKAGE-ERROR") or line.endswith(":PACKAGE-ERROR"):
            return None
        if line.startswith(":OTHER-ERROR") or line.endswith(":OTHER-ERROR"):
            return "unbalanced parentheses (unknown error)"
        if "READER-ERROR:" in line:
            # The condition report may continue on the following lines (the
            # actual cause is not necessarily on the token line), so fold the
            # continuation lines into the message before inspecting it.
            head = line.split("READER-ERROR:", 1)[1].strip()
            msg = " ".join(part for part in (head, *lines[idx + 1:]) if part)
            if "unmatched close parenthesis" in msg:
                return "unbalanced (extra close parenthesis)"
            return f"unbalanced ({msg[:60]})"
        if "EOF:" in line:
            return "unbalanced (missing close parenthesis)"

    return None


def is_reader_error(code_lines):
    """Feed code to SBCL's reader via a temp file.

    Args:
        code_lines: Lines of Lisp source, without trailing newlines.

    Returns:
        None when the code is blank or the reader reports no problem;
        otherwise a string describing the problem. That covers reader errors,
        a missing SBCL binary, a timeout, and SBCL exiting with a failure
        status without producing a recognisable result.
    """
    code = "\n".join(code_lines)
    if not code.strip():
        return None

    if not os.path.exists(SBCL):
        return f"SBCL not found at {SBCL}"

    # The Lisp side opens the file as UTF-8, so write it as UTF-8 regardless
    # of the locale's preferred encoding.
    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix=".lisp", delete=False, encoding="utf-8"
    )
    temp_path = tmp.name

    try:
        with tmp:
            tmp.write(code)

        # The path is embedded in a Lisp string literal: escape the two
        # characters that are special inside one.
        lisp_path = temp_path.replace("\\", "\\\\").replace('"', '\\"')

        # Use --no-userinit and --disable-debugger to suppress all interactive output
        result = subprocess.run(
            [SBCL, "--noinform", "--no-userinit", "--disable-debugger",
             "--quit", "--load", CHECKER_LISP,
             "--eval", f'(print (cp-check::check "{lisp_path}"))'],
            capture_output=True, text=True, errors="replace", timeout=10
        )
    except subprocess.TimeoutExpired:
        return "TIMEOUT (sbcl hung)"
    finally:
        try:
            os.unlink(temp_path)
        except OSError:
            pass

    verdict = parse_result(result.stdout)
    if verdict is None and result.returncode != 0:
        # No usable result and a failure status: the checker itself broke,
        # which must not be mistaken for "balanced".
        return f"SBCL exited with status {result.returncode}"
    return verdict


def main():
    """Command-line entry point.

    Parses ``sys.argv``, checks every file named on the command line, and
    returns the process exit status: 0 when all files pass, 1 when any file
    fails, 2 on a usage error (unknown option or no file given).
    """
    usage = f"Usage: {sys.argv[0]} [-v] <file.org> [...]"
    verbose = False
    targets = []

    for token in sys.argv[1:]:
        if token in ("-v", "--verbose"):
            verbose = True
            continue
        if token.startswith("-"):
            print(usage, file=sys.stderr)
            return 2
        targets.append(token)

    if not targets:
        print(usage, file=sys.stderr)
        return 2

    # Build the full list first so that every file is checked and reported,
    # even after an earlier one has failed.
    outcomes = [check_file(target, verbose) for target in targets]
    return 0 if all(outcomes) else 1


# Run as a script: hand main()'s status to the interpreter as the exit code.
if __name__ == "__main__":
    sys.exit(main())
