from __future__ import annotations
import re
from collections import defaultdict
from dataclasses import dataclass, field
from functools import cached_property, partial
from typing import TYPE_CHECKING
import panflute as pf
from import convert_text
from .helper import cancel_emph, cite_to_id_mode, cite_to_ref, merge_emph, parse_markdown_as_inline, to_emph
from .util import setup_logging
from typing import Union
from panflute.elements import Doc, Element
# Shape of one style entry in the amsthm metadata: a list of theorem names, or of
# mappings from a theorem name to the name(s) sharing its counter.
THM_DEF = list[Union[str, dict[str, str], dict[str, list[str]]]]
__version__: str = "2.0.0"
# LaTeX counters accepted as the numbered-within ``[parent]`` argument of ``\newtheorem``.
PARENT_COUNTERS: set[str] = {
    "part",
    "chapter",
    "section",
    "subsection",
    "subsubsection",
    "paragraph",
    "subparagraph",
}
# Theorem styles read from the metadata, in the order they are declared.
STYLES: tuple[str, ...] = ("plain", "definition", "remark")
# Document metadata key holding all options of this filter.
METADATA_KEY: str = "amsthm"
# A whole raw-TeX ``\ref{id}`` or ``\eqref{id}``; groups are (command, id).
REF_REGEX = re.compile(r"^\\(ref|eqref)\{(.*)\}$")
# Output formats for which theorems are emitted as native LaTeX environments.
LATEX_LIKE: set[str] = {"latex", "beamer"}
PLAIN_OR_DEF: set[str] = {"plain", "definition"}
# Default heading depth used for numbering in non-LaTeX output (1 -> "x.y").
COUNTER_DEPTH_DEFAULT: int = 1
# Module-level logger for this filter, obtained from the project's ``setup_logging`` helper (.util).
logger = setup_logging()
def parse_info(info: str | None) -> list[Element]:
    """Render the optional theorem info as parenthesised inline AST elements.

    A missing or empty ``info`` yields an empty list. Otherwise ``info`` is parsed
    as inline markdown and surrounded by ``(`` and ``)`` ``Str`` elements.
    """
    if not info:
        return []
    opening = [pf.Str("(")]
    body = parse_markdown_as_inline(info)
    return opening + body + [pf.Str(")")]
@dataclass
class NewTheorem:
    r"""A LaTeX amsthm new theorem.

    :param style: theorem style (one of ``STYLES``, or ``"proof"``).
    :param env_name: environment name; a trailing ``*`` is stripped and marks the
        theorem as unnumbered.
    :param text: header text shown for the theorem; defaults to ``env_name``.
    :param parent_counter: for LaTeX output, controlling the number of numbers in a theorem.
        Should be used with counter_depth to match LaTeX and non-LaTeX output.
        Must be one of ``PARENT_COUNTERS``; any other value is dropped with a warning.
    :param shared_counter: name of the theorem whose counter this one shares.
        Dropped with a warning when a (supported) ``parent_counter`` is also given.
    :param numbered: whether the theorem is numbered.
    """

    style: str
    env_name: str
    text: str = ""
    parent_counter: str | None = None
    shared_counter: str | None = None
    numbered: bool = True

    # NOTE(review): ``amsthm`` calls ``theorem.to_panflute_theorem_header(options, id, info)``,
    # which is not defined in this class as shown — confirm where it is implemented.

    def __post_init__(self) -> None:
        if self.env_name.endswith("*"):
            self.env_name = self.env_name[:-1]
            self.numbered = False
        if not self.text:
            logger.debug("Defaulting text to %s", self.env_name)
            self.text = self.env_name
        parent_counter = self.parent_counter
        if parent_counter is not None and parent_counter not in PARENT_COUNTERS:
            logger.warning("Unsupported parent_counter %s, ignoring.", parent_counter)
            # Really ignore it, as the warning says: otherwise it would still be
            # written into \newtheorem and would also knock out shared_counter below.
            self.parent_counter = None
        if self.numbered and self.parent_counter is not None and self.shared_counter is not None:
            logger.warning("Dropping shared_counter as parent_counter is defined.")
            self.shared_counter = None

    @property
    def latex(self) -> str:
        r"""The amsthm declaration of this theorem.

        - unnumbered: ``\newtheorem*{env}{Text}``
        - sharing a counter: ``\newtheorem{env}[shared]{Text}``
        - numbered within a parent counter: ``\newtheorem{env}{Text}[parent]``
        - otherwise: ``\newtheorem{env}{Text}``
        """
        res = [r"\newtheorem"]
        if not self.numbered:
            res.append(f"*{{{self.env_name}}}{{{self.text}}}")
        elif self.shared_counter is None:
            res.append(f"{{{self.env_name}}}{{{self.text}}}")
            if self.parent_counter is not None:
                res.append(f"[{self.parent_counter}]")
        else:
            res.append(f"{{{self.env_name}}}[{self.shared_counter}]{{{self.text}}}")
        return "".join(res)

    @property
    def class_name(self) -> str:
        """Name in pandoc div classes.

        It cannot have space, so spaces in ``env_name`` become underscores.
        """
        return self.env_name.replace(" ", "_")

    @property
    def counter_name(self) -> str:
        """Name of the counter this theorem advances: its own, or the shared one."""
        return self.env_name if self.shared_counter is None else self.shared_counter
@dataclass
class Proof(NewTheorem):
    r"""The amsthm ``proof`` environment.

    amsthm predefines ``proof``, so it is registered like a theorem (unnumbered)
    but skipped when the ``\newtheorem`` preamble is generated. The dataclass
    decorator is required so that ``Proof()`` can be built from the defaults below.
    """

    style: str = "proof"
    env_name: str = "proof"
    text: str = "proof"
    parent_counter: str | None = None
    shared_counter: str | None = None
    numbered: bool = False
@dataclass
class DocOptions:
    """Document options.

    :param theorems: theorem definitions keyed by their pandoc Div class name.
    :param counter_depth: can be n=0-6 inclusive.
        n means n+1 numbers shown in non-LaTeX outputs.
        e.g. n=1 means x.y, where x is the heading 1 counter, y is the theorem counter.
        Should be used with parent_counter to match LaTeX and non-LaTeX output.
        Values that cannot be converted to int fall back to ``COUNTER_DEPTH_DEFAULT``.
    :param counter_ignore_headings: heading texts that do not advance the heading counters.
    """

    theorems: dict[str, NewTheorem] = field(default_factory=dict)
    counter_depth: int = COUNTER_DEPTH_DEFAULT
    counter_ignore_headings: set[str] = field(default_factory=set)

    def __post_init__(self) -> None:
        try:
            self.counter_depth = int(self.counter_depth)
        except (TypeError, ValueError):
            # metadata can be any YAML value (list, mapping, null, non-numeric string)
            logger.warning("counter_depth must be int, default to %s.", COUNTER_DEPTH_DEFAULT)
            self.counter_depth = COUNTER_DEPTH_DEFAULT
        # initial count is zero
        # should be += 1 before using
        self.header_counters: list[int] = [0] * self.counter_depth
        # from identifiers to numbers
        self.identifiers: dict[str, str] = {}
        # create theorem_counters up front so the attribute always exists
        self.reset_theorem_counters()

    def reset_theorem_counters(self) -> None:
        """Reset every theorem counter to zero."""
        self.theorem_counters: dict[str, int] = defaultdict(int)

    @cached_property
    def theorems_set(self) -> set[str]:
        """Class names of all known theorems, for intersecting with Div classes."""
        # cached: the visible code never mutates ``theorems`` after construction
        return set(self.theorems)

    @classmethod
    def from_doc(
        cls,
        doc: Doc,
    ) -> DocOptions:
        """Build the options from the ``amsthm`` entry of the document metadata.

        Keys read: ``name_to_text`` (env name -> header text), ``parent_counter``,
        ``counter_depth``, ``counter_ignore_headings`` and one key per style in
        ``STYLES``. A style entry lists theorem definitions, each being a name or a
        mapping from a name to one name or a list of names sharing its counter.
        ``proof`` is always registered.
        """
        options: dict[str, str | dict[str, str] | THM_DEF] = doc.get_metadata(METADATA_KEY, {})
        name_to_text: dict[str, str] = options.get("name_to_text", {})  # type: ignore[assignment, arg-type]
        parent_counter: str | None = options.get("parent_counter", None)  # type: ignore[assignment, arg-type]

        theorems: dict[str, NewTheorem] = {}

        def add(style: str, env_name: str, **counters: str | None) -> None:
            """Create a theorem and register it under its Div class name."""
            theorem = NewTheorem(style, env_name, text=name_to_text.get(env_name, ""), **counters)
            theorems[theorem.class_name] = theorem

        for style in STYLES:
            option: THM_DEF = options.get(style, [])  # type: ignore[assignment]
            # A single definition may be written without the surrounding list;
            # iterating a bare string would define one theorem per character.
            if isinstance(option, (str, dict)):
                option = [option]
            for opt in option:
                if isinstance(opt, dict):
                    for key, value in opt.items():
                        # key: gets its own counter (optionally within parent_counter)
                        add(style, key, parent_counter=parent_counter)
                        # value(s): share the key's counter
                        for v in value if isinstance(value, list) else [value]:
                            add(style, v, shared_counter=key)
                else:
                    add(style, opt, parent_counter=parent_counter)
        # proof is predefined in amsthm
        theorems["proof"] = Proof()

        ignore_headings = options.get("counter_ignore_headings", [])
        # a single heading given as a bare string must not become a set of letters
        if isinstance(ignore_headings, str):
            ignore_headings = [ignore_headings]

        return cls(
            theorems=theorems,
            counter_depth=options.get("counter_depth", COUNTER_DEPTH_DEFAULT),  # type: ignore[arg-type] # will be verified at __post_init__
            counter_ignore_headings=set(ignore_headings),
        )

    @property
    def latex(self) -> str:
        r"""LaTeX preamble: every theorem's ``\newtheorem``, grouped under ``\theoremstyle``.

        A ``\theoremstyle`` line is emitted whenever the style changes while
        walking ``theorems`` in insertion order.
        """
        cur_style: str = ""
        res: list[str] = []
        for theorem in self.theorems.values():
            # proof is predefined in amsthm
            if isinstance(theorem, Proof):
                continue
            if theorem.style != cur_style:
                cur_style = theorem.style
                res.append(f"\\theoremstyle{{{cur_style}}}")
            res.append(theorem.latex)
        return "\n".join(res)

    @property
    def to_panflute(self) -> pf.RawBlock:
        """The LaTeX preamble wrapped in a raw LaTeX block."""
        return pf.RawBlock(self.latex, format="latex")
def prepare(doc: Doc) -> None:
    """Filter setup: parse the amsthm options and, for LaTeX output, add the preamble.

    The parsed ``DocOptions`` is stored on ``doc._amsthm`` for the later passes.
    For LaTeX-like formats, the block from ``DocOptions.to_panflute`` is
    inserted as the first element of the document body.
    """
    options = DocOptions.from_doc(doc)
    doc._amsthm = options
    if doc.format not in LATEX_LIKE:
        return
    doc.content.insert(0, options.to_panflute)
[docs]def amsthm(elem: Element, doc: Doc) -> None:
"""General amsthm transformation working for all document types.
Essentially we replicate LaTeX amsthm behavior in this filter.

Headers with level <= ``counter_depth`` advance ``header_counters`` and reset
the deeper levels; headings whose text is in ``counter_ignore_headings`` are
meant to be skipped. A Div whose classes match exactly one theorem gets the
header built by ``to_panflute_theorem_header`` inserted at the start of its
first block (or as a new Para). Judging by the remaining lines, proofs get a
right-floated end-of-proof mark (raw HTML) appended.
"""
# NOTE(review): this block is incomplete in this copy: indentation is lost and
# several lines are missing (marked below); restore it from the original source.
options: DocOptions = doc._amsthm
if isinstance(elem, pf.Header):
if elem.level <= options.counter_depth:
header_string = None
if (counter_ignore_headings := options.counter_ignore_headings) and (
header_string := pf.stringify(elem)
) in counter_ignore_headings:
logger.debug("Ignoring header %s in header_counters as it is in counter_ignore_headings", header_string)
# NOTE(review): nothing visible stops the counter update below for an ignored
# heading; an early `return` appears to be missing here — confirm.
# Header.level is 1-indexed, while list is 0-indexed
options.header_counters[elem.level - 1] += 1
# reset deeper levels
for i in range(elem.level, options.counter_depth):
options.header_counters[i] = 0
# NOTE(review): the call owning these arguments (presumably logger.debug) is missing.
"Header encounter: %s, current counter: %s", header_string or elem, options.header_counters
elif isinstance(elem, pf.Div):
environments: set[str] = options.theorems_set.intersection(elem.classes)
if environments:
if len(environments) != 1:
logger.warning("Multiple environments found: %s", environments)
return None
environment = environments.pop()
theorem = options.theorems[environment]
info = elem.attributes.get("info", None)
id = elem.identifier
res = theorem.to_panflute_theorem_header(options, id, info)
# theorem body
# NOTE(review): the compared value (presumably theorem.style) and this branch's
# body are missing; the `try:` matching the `except AttributeError` below is missing too.
if == "plain":
# insert in the beginning of the first block element
for r in reversed(res):
elem.content[0].content.insert(0, r)
except AttributeError:
# if fail, insert a Para before content
elem.content.insert(0, pf.Para(*res))
# NOTE(review): this literal looks like UTF-8 "◻" (U+25FB) mis-decoded as cp1252;
# confirm the file's encoding before relying on it.
r = pf.RawInline("<span style='float: right'>â—»</span>", format="html")
# insert in the end of the last block element
# NOTE(review): compared value (presumably theorem.style), the `try:` and the
# append statement are missing here.
if == "proof":
except AttributeError:
# if fail, append a Para
# NOTE(review): the fallback statement that appends the Para is missing.
def resolve_ref(elem: Element, doc: Doc) -> pf.Str | None:
    r"""Replace a reference to an amsthm environment by its number.

    Post-processing of references for non-LaTeX outputs. Only ids recorded in
    ``doc._amsthm.identifiers`` are resolved:

    - ``[@id]`` (NormalCitation) becomes ``(number)``, ``@id`` (AuthorInText)
      becomes ``number``; any other citation mode is logged and left alone;
    - raw TeX ``\ref{id}`` becomes ``number`` and ``\eqref{id}`` becomes ``(number)``.

    Returns the replacement ``pf.Str``, or ``None`` to keep ``elem`` unchanged.
    """
    options: DocOptions = doc._amsthm

    # citations: [@...] and @...
    if isinstance(elem, pf.Cite):
        resolved = cite_to_id_mode(elem)
        if resolved is None:
            return None
        ref_id = resolved[0]
        if ref_id not in options.identifiers:
            return None
        mode = resolved[1]
        if mode == "NormalCitation":
            return pf.Str(f"({options.identifiers[ref_id]})")
        if mode == "AuthorInText":
            return pf.Str(options.identifiers[ref_id])
        logger.warning("Unknown citation mode %s from Cite: %s. Ignoring...", mode, elem)
        return None

    # raw TeX: \ref{...} and \eqref{...}
    if not (isinstance(elem, pf.RawInline) and elem.format == "tex"):
        return None
    text = elem.text
    found = REF_REGEX.findall(text)
    if not found:
        return None
    if len(found) != 1:
        logger.warning("Ignoring ref matching in %s: %s", text, found)
        return None
    ref_type, ref_id = found[0]
    if ref_id not in options.identifiers:
        return None
    number = options.identifiers[ref_id]
    return pf.Str(f"({number})") if ref_type == "eqref" else pf.Str(number)
def collect_ref_id(elem: Element, doc: Doc) -> None:
    """Record the identifier of every amsthm environment Div.

    Pre-processing of references for LaTeX output; run it before ``amsthm_latex``.
    It is a separate pass because an id may be cited or referenced before the
    theorem defining it. ``doc._amsthm.identifiers`` is modified in place, and
    the stored value is ``""`` since LaTeX does the numbering itself.
    A Div matching more than one environment is logged and skipped.
    """
    options: DocOptions = doc._amsthm
    if not isinstance(elem, pf.Div):
        return None
    matched: set[str] = options.theorems_set.intersection(elem.classes)
    if not matched:
        return None
    if len(matched) > 1:
        logger.warning("Multiple environments found: %s", matched)
        return None
    identifier = elem.identifier
    if identifier:
        # only presence matters for LaTeX: \label/\ref handle the numbers
        options.identifiers[identifier] = ""
    return None
def amsthm_latex(elem: Element, doc: Doc) -> pf.RawBlock | None:
    r"""Transform amsthm environment Divs to LaTeX.

    - A Div whose classes match exactly one amsthm environment is replaced by a raw
      LaTeX block ``\begin{env}[info]\label{id}`` + converted Div + ``\end{env}``.
      ``[info]`` is emitted only when the Div has an ``info`` attribute (parsed as
      inline markdown, with ``cite_to_ref`` applied), and ``\label`` only when the
      Div has an identifier.
    - A Div matching several environments is logged and left unchanged; other
      Divs are left unchanged as well.
    - Any other element is passed to ``cite_to_ref`` with the collected ids
      (the check for ``pf.Cite`` is done inside ``cite_to_ref``).
    """
    options: DocOptions = doc._amsthm
    if not isinstance(elem, pf.Div):
        # check if pf.Cite is done inside cite_to_ref
        return cite_to_ref(elem, doc, options.identifiers)

    environments: set[str] = options.theorems_set.intersection(elem.classes)
    if not environments:
        return None
    if len(environments) != 1:
        logger.warning("Multiple environments found: %s", environments)
        return None
    theorem = options.theorems[environments.pop()]

    div_content = pf.convert_text(elem, input_format="panflute", output_format="latex")
    info = elem.attributes.get("info", None)
    identifier = elem.identifier

    res = [f"\\begin{{{theorem.env_name}}}"]
    if info:
        # wrap in Para for walk
        ast = pf.Para(*parse_markdown_as_inline(info))
        ast.walk(partial(cite_to_ref, check_id=options.identifiers))
        info_latex = pf.convert_text(ast, input_format="panflute", output_format="latex").strip()
        res.append(f"[{info_latex}]")
    if identifier:
        res.append(f"\\label{{{identifier}}}")
    # close the environment so the emitted LaTeX is always balanced
    res.append(f"\n{div_content}\n\\end{{{theorem.env_name}}}")
    return pf.RawBlock("".join(res), format="latex")
def action1(elem: Element, doc: Doc) -> None:
    """First filter pass.

    For LaTeX-like formats only the theorem ids are collected (``collect_ref_id``);
    for every other format ``amsthm`` builds the theorem headers and numbering.
    """
    if doc.format in LATEX_LIKE:
        collect_ref_id(elem, doc)
    else:
        # amsthm() inserts rendered theorem headers into the Divs; doing that for
        # LaTeX too would duplicate the header inside \begin{env}...\end{env}.
        amsthm(elem, doc)
    return None
def action2(elem: Element, doc: Doc) -> pf.RawBlock | pf.Str | pf.RawInline | None:
    """Second filter pass.

    LaTeX-like formats: render amsthm Divs (and citations) via ``amsthm_latex``,
    which can return a ``pf.RawBlock``. Other formats: replace references with
    theorem numbers via ``resolve_ref``.
    """
    if doc.format in LATEX_LIKE:
        return amsthm_latex(elem, doc)
    return resolve_ref(elem, doc)
def finalize(doc: Doc) -> None:
    """Drop the per-run amsthm state that ``prepare`` stored on ``doc._amsthm``."""
    delattr(doc, "_amsthm")
[docs]def main(doc: Doc | None = None) -> None:
return pf.run_filters(
(action1, action2),