#!/usr/bin/env python3
"""Extract a #+begin_src lisp block from an .org file and print it.

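Identify the block by:
  - index: --block 3  (1-based, counting all #+begin_src lisp blocks)
  - function: --function view-status  (finds the first block containing a
    defun/defmacro/defvar/... form that defines that name)

With neither option, a numbered listing of all blocks is printed to stderr.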

Output is the block content between begin and end markers, sans markers.

Usage:
  repl-block org/file.org --function foo | repl
  repl-block org/file.org --block 3
  repl-block org/file.org --function foo --package :my-package  (adds in-package prefix)
"""

import sys
import re
import argparse


# Opening marker of a Common Lisp source block, e.g. "#+BEGIN_SRC lisp :tangle yes".
# Case-insensitive; "emacs-lisp" and "lispy" are deliberately not matched.
LISP_BEGIN = re.compile(r"#\+begin_src\s+lisp\b", flags=re.I)
# Closing marker of any source block (case-insensitive).
END_SRC = re.compile(r"#\+end_src\b", flags=re.I)


def extract_blocks(lines):
    """Collect every ``#+begin_src lisp`` block found in ``lines``.

    Begin/end markers may be indented (leading whitespace is ignored when
    testing for them). A block that is still open when the input ends is
    not returned.

    Args:
        lines: iterable of text lines, e.g. from ``file.readlines()``.

    Returns:
        List of ``(line_no, body)`` tuples in file order. ``line_no`` is the
        1-based line number of the begin marker; ``body`` is the list of
        lines between the markers with trailing newlines removed.
    """
    found = []
    opened_at = None  # line number of the currently open begin marker, if any
    content = []
    for lineno, raw in enumerate(lines, 1):
        stripped = raw.lstrip()
        if opened_at is None:
            # Outside a block: only a lisp begin marker is interesting.
            if LISP_BEGIN.match(stripped):
                opened_at, content = lineno, []
            continue
        if END_SRC.match(stripped):
            found.append((opened_at, content))
            opened_at = None
        else:
            content.append(raw.rstrip("\n"))
    return found


def find_by_function(blocks, name):
    """Return the first block containing a ``def*`` form that defines ``name``.

    Each body line is checked for ``(defun NAME``, ``(defmacro NAME``,
    ``(defmethod NAME``, ``(defgeneric NAME``, ``(defvar NAME``,
    ``(defparameter NAME``, ``(defconstant NAME``, ``(defclass NAME``,
    ``(defstruct NAME`` or ``(defpackage NAME``. Leading whitespace is
    allowed, since Org src-block content is often indented.

    Args:
        blocks: ``(line_no, body)`` pairs as produced by ``extract_blocks``.
        name: Lisp symbol name to look for (compared case-sensitively).

    Returns:
        The matching ``(line_no, body)`` pair, or ``(None, None)`` if no
        block defines ``name``.
    """
    # Lisp symbols may contain '-', '*', '?', '!' ..., so a regex \b is the
    # wrong terminator: "view" must not match "view-status", and "*foo*"
    # must still match "(defvar *foo* 1)". Require the symbol to end at
    # whitespace, a parenthesis, or end of line instead.
    pattern = re.compile(
        r"\s*\(def(?:un|macro|method|generic|var|parameter|constant|class|struct|package)\s+"
        + re.escape(name)
        + r"(?=[\s()]|$)"
    )
    for line_no, body in blocks:
        if any(pattern.match(bline) for bline in body):
            return line_no, body
    return None, None


def find_by_index(blocks, idx):
    """Return the ``idx``-th block (1-based) as a ``(line_no, body)`` pair.

    Indices outside ``1..len(blocks)``, including zero and negatives, give
    ``(None, None)`` instead of wrapping around like Python negative indices.
    """
    if idx < 1 or idx > len(blocks):
        return None, None
    return blocks[idx - 1]


def main():
    """Command-line entry point: parse arguments, locate a block, print it.

    With ``--function`` or ``--block``, the selected block body is written to
    stdout, optionally preceded by an ``(in-package ...)`` form. With neither,
    a numbered listing of all blocks goes to stderr so stdout stays clean.

    Returns:
        Process exit status: 0 on success (including listing mode), 1 if the
        file cannot be read or the requested block is not found.
    """
    parser = argparse.ArgumentParser(description="Extract lisp blocks from org files")
    parser.add_argument("file", help=".org file to extract from")
    parser.add_argument("--block", type=int, default=None, help="Block number (1-based)")
    parser.add_argument("--function", type=str, default=None, help="Function name to find")
    parser.add_argument("--package", type=str, default=None, help="(in-package ...) prefix")
    args = parser.parse_args()

    try:
        with open(args.file, encoding="utf-8") as f:
            lines = f.readlines()
    except FileNotFoundError:
        print(f"File not found: {args.file}", file=sys.stderr)
        return 1
    except (OSError, UnicodeDecodeError) as e:
        # Other I/O failures (permissions, directories) and non-UTF-8 input.
        print(f"Error reading {args.file}: {e}", file=sys.stderr)
        return 1

    blocks = extract_blocks(lines)

    if args.function:
        _, body = find_by_function(blocks, args.function)
        if body is None:
            print(f"No block found containing function '{args.function}'", file=sys.stderr)
            return 1
    elif args.block is not None:
        # Compare with None: "--block 0" is falsy and must be reported as
        # not found, not silently fall through to the listing below.
        _, body = find_by_index(blocks, args.block)
        if body is None:
            print(f"Block {args.block} not found (file has {len(blocks)} blocks)", file=sys.stderr)
            return 1
    else:
        # Print listing to stderr so piping still works
        for idx, (line_no, block_body) in enumerate(blocks, 1):
            first = (block_body or [""])[0][:60]
            print(f"  {idx}: line {line_no}: {first}", file=sys.stderr)
        print(f"\n{len(blocks)} total blocks", file=sys.stderr)
        return 0

    if args.package:
        pkg = args.package
        # Pass through designators that are already complete (":pkg",
        # "#:pkg", "\"PKG\""); only a bare name gets a keyword colon.
        if not pkg.startswith((":", "#:", '"')):
            pkg = f":{pkg}"
        print(f"(in-package {pkg})")
        print()

    for line in body:
        print(line)
    return 0


if __name__ == "__main__":
    # main() returns the exit status; SystemExit hands it to the shell.
    raise SystemExit(main())
