#!/usr/bin/env python3
# pylint: disable=C0103,C0114,C0115,C0116,C0123,C0209,C0301,C0302,R0902,R0904,R0913,R0914,R0912,R0915,W0511,W0621
######################################################################

import argparse
import glob
import os
import re
import sys
import textwrap

# from pprint import pprint, pformat


# This class is used to represent both AstNode and DfgVertex sub-types
class Node:
    """A class in the AstNode or DfgVertex type hierarchy.

    Nodes are built in two phases. While headers are being read a Node is
    mutable: sub-classes, operands and pointer members may be added. Calling
    complete() on the root freezes the whole hierarchy (each sub-class list
    becomes a tuple) and assigns type ids, ordering indices and arities.
    The post-completion accessors assert that this has happened.
    """

    def __init__(self, name, superClass, file=None, lineno=None):
        self._name = name
        self._superClass = superClass  # None for the root of the hierarchy
        self._file = file  # File this class is defined in
        self._lineno = lineno  # Line this class is defined on
        # A list while the hierarchy is built; complete() makes it a tuple
        self._subClasses = []
        # Caches filled on demand after completion
        self._allSuperClasses = None
        self._allSubClasses = None
        self._typeIdMin = None
        self._typeIdMax = None
        # Numbering assigned by complete()
        self._typeId = None  # Only assigned to leaf classes
        self._ordIdx = None
        # Operand and pointer-member metadata
        self._arity = -1  # Highest operand number (inherited by complete())
        self._ops = {}  # Operand number -> (name, monad, kind)
        self._ptrs = []  # Dicts with 'name', 'monad' and 'kind' keys

    @property
    def name(self):
        return self._name

    @property
    def superClass(self):
        return self._superClass

    @property
    def isRoot(self):
        return self._superClass is None

    @property
    def isCompleted(self):
        # complete() marks a node by freezing its sub-class list to a tuple
        return isinstance(self._subClasses, tuple)

    @property
    def file(self):
        return self._file

    @property
    def lineno(self):
        return self._lineno

    @property
    def ptrs(self):
        assert self.isCompleted
        return self._ptrs

    # --- Building the hierarchy (before complete()) ---

    def addSubClass(self, subClass):
        assert not self.isCompleted
        self._subClasses.append(subClass)

    def addOp(self, n, name, monad, kind):
        """Declare operand n (1..4) and raise the arity to at least n."""
        assert 1 <= n <= 4
        self._ops[n] = (name, monad, kind)
        if n > self._arity:
            self._arity = n

    def getOp(self, n):
        """Return (name, monad, kind) for operand n, searching this class
        and then its super-classes; None if no class in the chain has it."""
        assert 1 <= n <= 4
        klass = self
        while True:
            op = klass._ops.get(n)  # pylint: disable=protected-access
            if op is not None or klass.isRoot:
                return op
            klass = klass.superClass

    def addPtr(self, name, monad, kind):
        """Record a pointer member; a leading 'm_' is dropped from name."""
        entry = {
            'name': re.sub(r'^m_', '', name),
            'monad': monad,
            'kind': kind
        }
        self._ptrs.append(entry)

    # Computes derived properties over entire class hierarchy.
    # No more changes to the hierarchy are allowed once this was called
    def complete(self, typeId=0, ordIdx=0):
        """Freeze this sub-tree and number its classes.

        Sub-classes without sub-classes of their own are ordered first, each
        group by name. Classes receive consecutive ordIdx values in
        pre-order and leaves receive consecutive typeIds. Returns the next
        free (typeId, ordIdx) pair.
        """
        assert not self.isCompleted

        def sortKey(sub):
            hasChildren = len(sub._subClasses) > 0  # pylint: disable=protected-access
            return (hasChildren, sub.name)

        self._subClasses = tuple(sorted(self._subClasses, key=sortKey))
        self._ordIdx = ordIdx
        nextOrdIdx = ordIdx + 1
        # The parent is completed before its children, so its arity is final
        self._arity = 0 if self.isRoot else max(self._arity,
                                                self._superClass.arity)

        if self.isLeaf:
            self._typeId = typeId
            return typeId + 1, nextOrdIdx

        nextTypeId = typeId
        for sub in self._subClasses:
            nextTypeId, nextOrdIdx = sub.complete(nextTypeId, nextOrdIdx)
        return nextTypeId, nextOrdIdx

    # --- Queries (after complete()) ---

    @property
    def subClasses(self):
        assert self.isCompleted
        return self._subClasses

    @property
    def isLeaf(self):
        assert self.isCompleted
        return len(self.subClasses) == 0

    @property
    def allSuperClasses(self):
        """All ancestors, root first."""
        assert self.isCompleted
        if self._allSuperClasses is None:
            parent = self.superClass
            self._allSuperClasses = (() if parent is None else
                                     parent.allSuperClasses + (parent, ))
        return self._allSuperClasses

    @property
    def allSubClasses(self):
        """Direct sub-classes followed by each one's descendants."""
        assert self.isCompleted
        if self._allSubClasses is None:
            descendants = list(self.subClasses)
            for sub in self.subClasses:
                descendants.extend(sub.allSubClasses)
            self._allSubClasses = tuple(descendants)
        return self._allSubClasses

    @property
    def typeId(self):
        assert self.isCompleted
        assert self.isLeaf
        return self._typeId

    @property
    def typeIdMin(self):
        """Smallest typeId of this class or any of its descendants."""
        assert self.isCompleted
        if not self.isLeaf:
            if self._typeIdMin is None:
                self._typeIdMin = min(sub.typeIdMin
                                      for sub in self.allSubClasses)
            return self._typeIdMin
        return self.typeId

    @property
    def typeIdMax(self):
        """Largest typeId of this class or any of its descendants."""
        assert self.isCompleted
        if not self.isLeaf:
            if self._typeIdMax is None:
                self._typeIdMax = max(sub.typeIdMax
                                      for sub in self.allSubClasses)
            return self._typeIdMax
        return self.typeId

    @property
    def ordIdx(self):
        assert self.isCompleted
        return self._ordIdx

    @property
    def arity(self):
        assert self.isCompleted
        return self._arity

    def isSubClassOf(self, other):
        """True if self is other or a (transitive) sub-class of it."""
        assert self.isCompleted
        return self is other or self in other.allSubClasses


# AstNode hierarchy: class name without the "Ast" prefix -> Node
# (looked up by TREEOP parsing, e.g. AstNodes["NodeTriop"])
AstNodes = {}
# Sequence of AstNode Nodes iterated by the code generators; assigned
# outside this view (presumably after completion) -- TODO confirm
AstNodeList = None

# Presumably the DfgVertex counterparts of the two above (name without the
# "Dfg" prefix -> Node, and the ordered list) -- TODO confirm with caller
DfgVertices = {}
DfgVertexList = None

# "Ast<Name>" -> {'newed': {file: 1}, 'used': {file: 1}}, filled by read_refs()
ClassRefs = {}
# "<Stage>.cpp" -> ordering number (from 100, by first mention), filled by
# read_stages()
Stages = {}


class Cpt:
    """Expand TREEOP*/TREE_SKIP_VISIT directives in a C++ source file.

    process() copies the input to the output file verbatim, except that each
    indented line starting with TREE is kept as a '//' comment and parsed by
    tree_line(). At the position of the first TREEOP, the match functions
    and visitors for every TREEOP in the file are emitted (tree_match_base).
    Generation is deferred: callables are queued in out_lines and run only
    while the output is written, so the '#line' directives they emit know
    the final output line numbers.
    """

    def __init__(self):
        self.did_out_tree = False  # tree_match_base() already queued
        self.in_filename = ""
        self.in_linenum = 1  # Input line of the directive being parsed
        self.out_filename = ""
        self.out_linenum = 1  # Output lines written so far, for '#line'
        self.out_lines = []  # Queued output: str, or callable taking self
        self.tree_skip_visit = {}  # Type names given to TREE_SKIP_VISIT
        self.treeop = {}  # Type name -> list of match descriptor dicts
        self._exec_nsyms = 0  # Number of '$' symbols in current exec func
        self._exec_syms = {}  # '$sym' -> generated local name 'argNp'

    def error(self, txt):
        """Exit with an error attributed to the current input line."""
        sys.exit("%%Error: %s:%d: %s" %
                 (self.in_filename, self.in_linenum, txt))

    def print(self, txt):
        """Queue literal text for output."""
        self.out_lines.append(txt)

    def output_func(self, func):
        """Queue a callable, invoked with this Cpt when output is written."""
        self.out_lines.append(func)

    def _output_line(self):
        # '#line' naming the output line right after this directive: the
        # banner written by open_file() is output line 1, hence the +2
        self.print("#line " + str(self.out_linenum + 2) + " \"" +
                   self.out_filename + "\"\n")

    def process(self, in_filename, out_filename):
        """Expand in_filename into out_filename (see class docstring)."""
        self.in_filename = in_filename
        self.out_filename = out_filename
        ln = 0
        didln = False

        # Read the file and parse into list of functions that generate output
        with open(self.in_filename, "r", encoding="utf8") as fhi:
            for line in fhi:
                ln += 1
                if not didln:
                    # Re-synchronize compiler line numbers with the input
                    self.print("#line " + str(ln) + " \"" + self.in_filename +
                               "\"\n")
                    didln = True
                match = re.match(r'^\s+(TREE.*)$', line)
                if match:
                    func = match.group(1)
                    self.in_linenum = ln
                    self.print("//" + line)
                    self.output_func(lambda self: self._output_line())
                    self.tree_line(func)
                    didln = False
                elif not re.match(r'^\s*/[/\*]\s*TREE', line) and re.search(
                        r'\s+TREE', line):
                    self.error("Unknown astgen line: " + line)
                else:
                    self.print(line)

        # Put out the resultant file, if the list has a reference to a
        # function, then call that func to generate output
        with open_file(self.out_filename) as fho:
            togen = self.out_lines
            for line in togen:
                if isinstance(line, str):
                    self.out_lines = [line]
                else:
                    self.out_lines = []
                    line(self)  # lambda call
                for out in self.out_lines:
                    # Keep the output line count current for later '#line's
                    self.out_linenum += out.count("\n")
                    fho.write(out)

    def tree_line(self, func):
        """Parse one TREEOP*/TREE_SKIP_VISIT directive and record it."""
        func = re.sub(r'\s*//.*$', '', func)
        func = re.sub(r'\s*;\s*$', '', func)

        # doflag "S" indicates an op specifying short-circuiting for a type.
        match = re.search(
            #       1   2                 3                  4
            r'TREEOP(1?)([ACSV]?)\s*\(\s*\"([^\"]*)\"\s*,\s*\"([^\"]*)\"\s*\)',
            func)
        match_skip = re.search(r'TREE_SKIP_VISIT\s*\(\s*\"([^\"]*)\"\s*\)',
                               func)

        if match:
            order = match.group(1)
            doflag = match.group(2)
            fromn = match.group(3)
            to = match.group(4)
            # self.print("// $fromn $to\n")
            if not self.did_out_tree:
                # All TREEOPs are emitted together, at the first one
                self.did_out_tree = True
                self.output_func(lambda self: self.tree_match_base())
            match = re.search(r'Ast([a-zA-Z0-9]+)\s*\{(.*)\}\s*$', fromn)
            if not match:
                self.error("Can't parse from function: " + func)
            typen = match.group(1)
            subnodes = match.group(2)
            # Report unknown types as errors rather than crashing on lookup
            if typen not in AstNodes or AstNodes[typen].isRoot:
                self.error("Unknown AstNode typen: " + typen + ": in " + func)

            mif = ""
            if doflag == '':
                mif = "m_doNConst"
            elif doflag == 'A':
                mif = ""
            elif doflag == 'C':
                mif = "m_doCpp"
            elif doflag == 'S':
                mif = "m_doNConst"  # Not just for m_doGenerate
            elif doflag == 'V':
                mif = "m_doV"
            else:
                self.error("Unknown flag: " + doflag)
            # ",," in a condition stands for a literal comma
            subnodes = re.sub(r',,', '__ESCAPEDCOMMA__', subnodes)
            for subnode in re.split(r'\s*,\s*', subnodes):
                subnode = re.sub(r'__ESCAPEDCOMMA__', ',', subnode)
                if re.match(r'^\$([a-zA-Z0-9]+)$', subnode):
                    continue  # "$lhs" is just a comment that this op has a lhs
                subnodeif = subnode
                subnodeif = re.sub(
                    r'\$([a-zA-Z0-9]+)\.cast([A-Z][A-Za-z0-9]+)$',
                    r'VN_IS(nodep->\1(),\2)', subnodeif)
                subnodeif = re.sub(r'\$([a-zA-Z0-9]+)\.([a-zA-Z0-9]+)$',
                                   r'nodep->\1()->\2()', subnodeif)
                subnodeif = self.add_nodep(subnodeif)
                if mif != "" and subnodeif != "":
                    mif += " && "
                mif += subnodeif

            exec_func = self.treeop_exec_func(to)
            exec_func = re.sub(
                r'([-()a-zA-Z0-9_>]+)->cast([A-Z][A-Za-z0-9]+)\(\)',
                r'VN_CAST(\1,\2)', exec_func)

            if typen not in self.treeop:
                self.treeop[typen] = []
            n = len(self.treeop[typen])
            typefunc = {
                'order': order,
                'comment': func,
                'match_func': "match_" + typen + "_" + str(n),
                'match_if': mif,
                'exec_func': exec_func,
                'uinfo': re.sub(r'[ \t\"\{\}]+', ' ', func),
                'uinfo_level': (0 if re.match(r'^!', to) else 7),
                'short_circuit': (doflag == 'S'),
            }
            self.treeop[typen].append(typefunc)

        elif match_skip:
            typen = match_skip.group(1)
            if typen not in AstNodes:
                self.error("Unknown node type: " + typen)
            self.tree_skip_visit[typen] = 1

        else:
            self.error("Unknown astgen op: " + func)

    @staticmethod
    def add_nodep(strg):
        """Rewrite each '$name' into the accessor call 'nodep->name()'."""
        strg = re.sub(r'\$([a-zA-Z0-9]+)', r'nodep->\1()', strg)
        return strg

    def _exec_syms_recurse(self, aref):
        # Number each distinct '$sym' in first-seen (depth-first) order
        for sym in aref:
            if isinstance(sym, list):
                self._exec_syms_recurse(sym)
            elif re.search(r'^\$.*', sym):
                if sym not in self._exec_syms:
                    self._exec_nsyms += 1
                    self._exec_syms[sym] = "arg" + str(self._exec_nsyms) + "p"

    def _exec_new_recurse(self, aref):
        # Build 'new AstX(nodep->fileline(), args...)' from [name, args...]
        out = "new " + aref[0] + "(nodep->fileline()"
        first = True
        for sym in aref:
            if first:
                first = False
                continue
            out += ", "
            if isinstance(sym, list):
                out += self._exec_new_recurse(sym)
            elif re.match(r'^\$.*', sym):
                out += self._exec_syms[sym]
            else:
                out += sym
        return out + ")"

    def treeop_exec_func(self, func):
        """Translate the TREEOP replacement text into C++ statements.

        Accepts a function call (with '$sym' operands), a constructor tree
        'AstX{...}' that replaces nodep, "NEVER" or "DONE"; a leading '!'
        is ignored here.
        """
        out = ""
        func = re.sub(r'^!', '', func)

        if re.match(r'^\s*[a-zA-Z0-9]+\s*\(', func):  # Function call
            outl = re.sub(r'\$([a-zA-Z0-9]+)', r'nodep->\1()', func)
            out += outl + ";"
        elif re.match(r'^\s*Ast([a-zA-Z0-9]+)\s*\{\s*(.*)\s*\}$', func):
            aref = None
            # Recursive array with structure to form
            astack = []
            forming = ""
            argtext = func + "\000"  # EOF character
            for tok in argtext:
                if tok == "\000":
                    pass
                elif re.match(r'\s+', tok):
                    pass
                elif tok == "{":
                    newref = [forming]
                    if not aref:
                        aref = []
                    aref.append(newref)
                    astack.append(aref)
                    aref = newref
                    forming = ""
                elif tok == "}":
                    if forming:
                        aref.append(forming)
                    if len(astack) == 0:
                        self.error("Too many } in execution function: " + func)
                    aref = astack.pop()
                    forming = ""
                elif tok == ",":
                    if forming:
                        aref.append(forming)
                    forming = ""
                else:
                    forming += tok
            if not (aref and len(aref) == 1):
                self.error("Badly formed execution function: " + func)
            aref = aref[0]

            # Assign numbers to each $ symbol
            self._exec_syms = {}
            self._exec_nsyms = 0
            self._exec_syms_recurse(aref)

            # Unlink operands in numbering order (dicts keep insertion
            # order); sorting the 'argNp' strings would put arg10p first.
            for sym, argnp in self._exec_syms.items():
                arg = self.add_nodep(sym)
                out += "AstNodeExpr* " + argnp + " = " + arg + "->unlinkFrBack();\n"

            out += "AstNodeExpr* newp = " + self._exec_new_recurse(
                aref) + ";\n"
            out += "nodep->replaceWith(newp);"
            out += "VL_DO_DANGLING(nodep->deleteTree(), nodep);"
        elif func == "NEVER":
            out += "nodep->v3fatalSrc(\"Executing transform that was NEVERed\");"
        elif func == "DONE":
            pass
        else:
            self.error("Unknown execution function format: " + func + "\n")
        return out

    def tree_match_base(self):
        """Emit all match functions followed by the visitors calling them."""
        self.tree_match()
        self.tree_base()

    def tree_match(self):
        """Emit one 'bool match_<Type>_<n>(AstType*)' per recorded TREEOP."""
        self.print(
            "    // TREEOP functions, each return true if they matched & transformed\n"
        )
        for base in sorted(self.treeop.keys()):
            for typefunc in self.treeop[base]:
                self.print("    // Generated by astgen\n")
                self.print("    bool " + typefunc['match_func'] + "(Ast" +
                           base + "* nodep) {\n")
                self.print("\t// " + typefunc['comment'] + "\n")
                self.print("\tif (" + typefunc['match_if'] + ") {\n")
                self.print("\t    UINFO(" + str(typefunc['uinfo_level']) +
                           ", cvtToHex(nodep)" + " << \" " +
                           typefunc['uinfo'] + "\\n\");\n")
                self.print("\t    " + typefunc['exec_func'] + "\n")
                self.print("\t    return true;\n")
                self.print("\t}\n")
                self.print("\treturn false;\n")
                self.print("    }\n")

    def tree_base(self):
        """Emit a visit() per AstNode type calling the applicable matches.

        Matches declared on the type and all its super-classes are tried;
        TREEOP1 matches are moved to the front.
        """
        self.print("    // TREEOP visitors, call each base type's match\n")
        self.print(
            "    // Bottom class up, as more simple transforms are generally better\n"
        )
        for node in AstNodeList:
            out_for_type_sc = []
            out_for_type = []
            classes = list(node.allSuperClasses)
            classes.append(node)
            for base in classes:
                base = base.name
                if base not in self.treeop:
                    continue
                for typefunc in self.treeop[base]:
                    lines = [
                        "        if (" + typefunc['match_func'] +
                        "(nodep)) return;\n"
                    ]
                    if typefunc['short_circuit']:  # short-circuit match fn
                        out_for_type_sc.extend(lines)
                    else:  # Standard match fn
                        if typefunc[
                                'order']:  # TREEOP1's go in front of others
                            out_for_type = lines + out_for_type
                        else:
                            out_for_type.extend(lines)

            # We need to deal with two cases. For short circuited functions we
            # evaluate the LHS, then apply the short-circuit matches, then
            # evaluate the RHS and possibly THS (ternary operators may
            # short-circuit) and apply all the other matches.

            # For types without short-circuits, we just use iterateChildren, which
            # saves one comparison.
            if out_for_type_sc:  # Short-circuited types
                self.print(
                    "    // Generated by astgen with short-circuiting\n" +
                    "    void visit(Ast" + node.name +
                    "* nodep) override {\n" +
                    "      iterateAndNextNull(nodep->{op1}());\n".format(
                        op1=node.getOp(1)[0]) + "".join(out_for_type_sc))
                # The remaining operands must always be visited and the
                # function always closed, even if every match short-circuits
                self.print(
                    "      iterateAndNextNull(nodep->{op2}());\n".format(
                        op2=node.getOp(2)[0]))
                if node.isSubClassOf(AstNodes["NodeTriop"]):
                    self.print(
                        "      iterateAndNextNull(nodep->{op3}());\n".
                        format(op3=node.getOp(3)[0]))
                self.print("".join(out_for_type) + "    }\n")
            elif out_for_type:  # Other types with something to print
                skip = node.name in self.tree_skip_visit
                gen = "Gen" if skip else ""
                virtual = "virtual " if skip else ""
                override = "" if skip else " override"
                self.print(
                    "    // Generated by astgen\n" + "    " + virtual +
                    "void visit" + gen + "(Ast" + node.name + "* nodep)" +
                    override + " {\n" +
                    ("" if skip else "        iterateChildren(nodep);\n") +
                    ''.join(out_for_type) + "    }\n")


######################################################################
######################################################################


def partitionAndStrip(string, separator):
    """Split *string* at the first *separator* and strip each piece.

    Returns an iterator over the three str.partition() parts (head, the
    separator or "", tail) with surrounding whitespace removed; callers
    unpack it straight into three names.
    """
    head, sep, tail = string.partition(separator)
    return map(str.strip, (head, sep, tail))


def parseOpType(string):
    """Parse an '@astgen' operand type.

    Returns (monad, kind) where monad is "", "Optional" or "List" and kind
    is the class name without its "Ast" prefix, e.g. "Optional[AstVar]" ->
    ("Optional", "Var"). Returns None for anything else, including nested
    monads.
    """
    wrapped = re.match(r'^(\w+)\[(\w+)\]$', string)
    if not wrapped:
        plain = re.match(r'^Ast(\w+)$', string)
        return ("", string[3:]) if plain else None

    monad, inner = wrapped.groups()
    if monad not in ("Optional", "List"):
        return None
    innerType = parseOpType(inner)
    # Only a bare Ast* type may be wrapped in a monad
    if innerType is None or innerType[0]:
        return None
    return monad, innerType[1]


def read_types(filename, Nodes, prefix):
    """Parse the class hierarchy declared in a header file.

    Every 'class'/'struct' whose public super-class mentions *prefix* is
    added to *Nodes* (keyed by its name without the prefix) and attached to
    its super-class. For the "Ast" prefix, '// @astgen' directives declaring
    operands (op1..op4, alias opN) and pointer members (ptr) are parsed too,
    and a few coding rules are checked. Problems are reported on stderr and
    the program exits once the whole file has been read.

    Args:
        filename: header file to read.
        Nodes: dict of known classes, updated in place. Super-classes are
            looked up in it, so each one must be read earlier or be
            pre-seeded (presumably the root is pre-seeded by the caller).
        prefix: class name prefix, e.g. "Ast".
    """
    hasErrors = False

    def error(lineno, message):
        nonlocal hasErrors
        print(filename + ":" + str(lineno) + ": %Error: " + message,
              file=sys.stderr)
        hasErrors = True

    node = None  # Class whose body is currently being read
    hasAstgenMembers = False  # Seen 'ASTGEN_MEMBERS_<class>;' in 'node'

    def checkFinishedNode(node):
        # Called when a class ends: report it if it lacked ASTGEN_MEMBERS_
        nonlocal hasAstgenMembers
        if not node:
            return
        if not hasAstgenMembers:
            error(
                node.lineno,
                "'{p}{n}' does not contain 'ASTGEN_MEMBERS_{p}{n};'".format(
                    p=prefix, n=node.name))
        hasAstgenMembers = False

    with open(filename, "r", encoding="utf8") as fh:
        for (lineno, line) in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue

            match = re.search(r'^\s*(class|struct)\s*(\S+)', line)
            if match:
                classn = match.group(2)
                match = re.search(r':\s*public\s+(\S+)', line)
                supern = match.group(1) if match else ""
                if re.search(prefix, supern):
                    classn = re.sub(r'^' + prefix, '', classn)
                    supern = re.sub(r'^' + prefix, '', supern)
                    if not supern:
                        sys.exit("%Error: '{p}{c}' has no super-class".format(
                            p=prefix, c=classn))
                    checkFinishedNode(node)
                    if supern not in Nodes:
                        # Report instead of crashing with a KeyError
                        sys.exit(
                            "%Error: {f}:{n}: '{p}{c}' has unknown super-class '{p}{s}'"
                            .format(f=filename,
                                    n=lineno,
                                    p=prefix,
                                    c=classn,
                                    s=supern))
                    superClass = Nodes[supern]
                    node = Node(classn, superClass, filename, lineno)
                    superClass.addSubClass(node)
                    Nodes[classn] = node
            if not node:
                continue

            if re.match(r'^\s*ASTGEN_MEMBERS_' + prefix + node.name + ';',
                        line):
                hasAstgenMembers = True

            if prefix != "Ast":
                continue

            match = re.match(r'^\s*//\s*@astgen\s+(.*)$', line)
            if match:
                decl = re.sub(r'//.*$', '', match.group(1))
                what, sep, rest = partitionAndStrip(decl, ":=")
                what = re.sub(r'\s+', ' ', what)
                if not sep:
                    error(
                        lineno,
                        "Malformed '@astgen' directive (expecting '<keywords> := <description>'): "
                        + decl)
                elif what in ("op1", "op2", "op3", "op4"):
                    n = int(what[-1])
                    # partitionAndStrip() already strips ident and kind
                    ident, sep, kind = partitionAndStrip(rest, ":")
                    if not sep or not re.match(r'^\w+$', ident):
                        error(
                            lineno, "Malformed '@astgen " + what +
                            "' directive (expecting '" + what +
                            " := <identifier> : <type>': " + decl)
                    else:
                        kind = parseOpType(kind)
                        if not kind:
                            error(
                                lineno, "Bad type for '@astgen " + what +
                                "' (expecting Ast*, Optional[Ast*], or List[Ast*]):"
                                + decl)
                        elif node.getOp(n) is not None:
                            error(
                                lineno, "Already defined " + what + " for " +
                                node.name)
                        else:
                            node.addOp(n, ident, *kind)
                elif what in ("alias op1", "alias op2", "alias op3",
                              "alias op4"):
                    n = int(what[-1])
                    ident = rest.strip()
                    if not re.match(r'^\w+$', ident):
                        error(
                            lineno, "Malformed '@astgen " + what +
                            "' directive (expecting '" + what +
                            " := <identifier>': " + decl)
                    else:
                        op = node.getOp(n)
                        if op is None:
                            error(lineno,
                                  "Aliased op" + str(n) + " is not defined")
                        else:
                            node.addOp(n, ident, *op[1:])
                elif what == "ptr":
                    ident, _, kind = partitionAndStrip(rest, ":")
                    kind = parseOpType(kind)
                    identOk = re.match(r'^m_(\w+)$', ident)
                    if not kind:
                        error(
                            lineno, "Bad type for '@astgen " + what +
                            "' (expecting Ast*, Optional[Ast*], or List[Ast*]):"
                            + decl)
                    if not identOk:
                        error(
                            lineno, "Malformed '@astgen ptr'"
                            " identifier (expecting m_ in '" + ident + "')")
                    # Only record the member when both parts are valid;
                    # a bad type leaves 'kind' as None, which can't be unpacked
                    if kind and identOk:
                        node.addPtr(ident, *kind)
                else:
                    error(
                        lineno,
                        "Malformed @astgen what (expecting 'op1'..'op4'," +
                        " 'alias op1'.., 'ptr'): " + what)
            else:
                line = re.sub(r'//.*$', '', line)
                if re.match(r'.*[Oo]p[1-9].*', line):
                    error(lineno,
                          "Use generated accessors to access op<N> operands")

            if re.match(
                    r'^\s*Ast[A-Z][A-Za-z0-9_]+\s*\*(\s*const)?\s+m_[A-Za-z0-9_]+\s*;',
                    line):
                error(lineno,
                      "Use '@astgen ptr' for Ast pointer members: " + line)

        checkFinishedNode(node)
    if hasErrors:
        sys.exit("%Error: Stopping due to errors reported above")


def check_types(sortedTypes, prefix, abstractPrefix):
    """Validate naming and source ordering of a completed class hierarchy.

    Exits if a leaf class is named with *abstractPrefix*, or a non-leaf
    class is not. Then, for each file, reports every class whose definition
    does not follow the expected order (non-leaf classes before leaf
    classes, each group by ordIdx) on stderr, and exits if there were any.
    """
    baseClass = prefix + abstractPrefix

    # Abstract (non-leaf) classes must carry the abstract prefix; concrete
    # (leaf) classes must not.
    for node in sortedTypes:
        abstractName = re.match(r'^' + abstractPrefix, node.name) is not None
        if abstractName and node.isLeaf:
            sys.exit(
                "%Error: Final {b} subclasses must not be named {b}*: {p}{n}"
                .format(b=baseClass, p=prefix, n=node.name))
        if not abstractName and not node.isLeaf:
            sys.exit(
                "%Error: Non-final {b} subclasses must be named {b}*: {p}{n}"
                .format(b=baseClass, p=prefix, n=node.name))

    def predecessorMap(ordered):
        # Each element -> the element before it (None for the first one)
        return dict(zip(ordered, (None, ) + ordered[:-1]))

    hasOrderingError = False
    for file in sorted({n.file for n in sortedTypes if n.file is not None}):
        inFile = tuple(n for n in sortedTypes if n.file == file)
        expect = predecessorMap(
            tuple(sorted(inFile, key=lambda n: (n.isLeaf, n.ordIdx))))
        actual = predecessorMap(tuple(sorted(inFile, key=lambda n: n.lineno)))
        for node in inFile:
            pred = expect[node]
            if pred == actual[node]:
                continue
            hasOrderingError = True
            if pred:
                where = "right after '" + prefix + pred.name + "'"
            else:
                where = "first in file"
            print(
                "{file}:{lineno}: %Error: Definition of '{p}{n}' is out of order. Should be {where}."
                .format(file=file,
                        lineno=node.lineno,
                        p=prefix,
                        n=node.name,
                        where=where),
                file=sys.stderr)

    if hasOrderingError:
        sys.exit(
            "%Error: Stopping due to out of order definitions listed above")


def read_stages(filename):
    """Record, in Stages, the order in which passes are mentioned.

    Each 'Name::' reference (preceded by whitespace, outside '//' comments)
    registers 'Name.cpp' with the next number starting at 100; only the
    first mention of a stage counts.
    """
    nextIndex = 100
    with open(filename, "r", encoding="utf8") as fh:
        for rawLine in fh:
            code = re.sub(r'//.*$', '', rawLine)
            if re.match(r'^\s*$', code):
                continue
            found = re.search(r'\s([A-Za-z0-9]+)::', code)
            if not found:
                continue
            stage = found.group(1) + ".cpp"
            if stage in Stages:
                continue
            Stages[stage] = nextIndex
            nextIndex += 1


def read_refs(filename):
    """Record in ClassRefs which Ast classes a source file news and uses.

    Entries are keyed by Ast class name and hold 'newed' and 'used' dicts
    of source basenames. 'new AstX' counts as newed; any 'AstX' token, or
    the type argument of VN_IS/VN_AS/VN_CAST, counts as used.
    """
    basename = re.sub(r'.*/', '', filename)

    def record(ref, how):
        entry = ClassRefs.setdefault(ref, {'newed': {}, 'used': {}})
        entry[how][basename] = 1

    with open(filename, "r", encoding="utf8") as fh:
        for rawLine in fh:
            code = re.sub(r'//.*$', '', rawLine)
            for found in re.finditer(r'\bnew\s*(Ast[A-Za-z0-9_]+)', code):
                record(found.group(1), 'newed')
            for found in re.finditer(r'\b(Ast[A-Za-z0-9_]+)', code):
                record(found.group(1), 'used')
            for found in re.finditer(
                    r'(VN_IS|VN_AS|VN_CAST)\([^.]+, ([A-Za-z0-9_]+)', code):
                record("Ast" + found.group(2), 'used')


def open_file(filename):
    """Create *filename* for writing and emit the 'Generated by astgen'
    banner line (with a C++ mode line unless it is a .txt file).

    Returns the open file object; the caller is responsible for closing it.
    """
    if re.search(r'\.txt$', filename):
        banner = "// Generated by astgen\n"
    else:
        banner = ('// Generated by astgen // -*- mode: C++; c-file-style: "cc-mode" -*-'
                  + "\n")
    fh = open(filename, "w", encoding="utf8")  # pylint: disable=consider-using-with
    fh.write(banner)
    return fh


# ---------------------------------------------------------------------


def write_report(filename):
    """Write a human-readable report of stages and AstNode classes.

    Lists the stages in Stages order, then for every class in AstNodeList
    its arity, super-classes (excluding the root), sub-classes and, if
    known from ClassRefs, the files that new and use it.
    """

    def stageOrder(stage):
        # Files that are not a known stage sort first
        return Stages[stage] if stage in Stages else -1

    with open_file(filename) as fh:

        fh.write(
            "Processing stages (approximate, based on order in Verilator.cpp):\n"
        )
        for stage in sorted(Stages, key=Stages.get):
            fh.write("  " + stage + "\n")

        fh.write("\nClasses:\n")
        for node in AstNodeList:
            fh.write("  class Ast%-17s\n" % node.name)
            fh.write("    arity:  {}\n".format(node.arity))
            parents = "".join("Ast%-12s " % sup.name
                              for sup in node.allSuperClasses
                              if not sup.isRoot)
            fh.write("    parent: " + parents + "\n")
            children = "".join("Ast%-12s " % sub.name
                               for sub in node.allSubClasses)
            fh.write("    childs:  " + children + "\n")
            refs = ClassRefs.get("Ast" + node.name)
            if refs is not None:
                for label, how in (("    newed:  ", 'newed'),
                                   ("    used:   ", 'used')):
                    stages = sorted(refs[how], key=stageOrder)
                    fh.write(label + "".join(s + "  " for s in stages) + "\n")
            fh.write("\n")


################################################################################
# Common code generation
################################################################################


def write_forward_class_decls(prefix, nodeList):
    """Write a forward declaration for every class in `nodeList`.

    Each output line declares `class <prefix><Name>;` (padded) followed by a
    trailing comment listing all superclasses of that class, root included.
    """
    outName = "V3" + prefix + "__gen_forward_class_decls.h"
    with open_file(outName) as fh:
        for cls in nodeList:
            decl = "class {}{:<17} // ".format(prefix, cls.name + ";")
            supers = "".join("{}{:<12} ".format(prefix, sup.name)
                             for sup in cls.allSuperClasses)
            fh.write(decl + supers + "\n")


def write_visitor_decls(prefix, nodeList):
    """Write a `virtual void visit(<prefix><Name>*);` line per non-root class."""
    with open_file("V3{}__gen_visitor_decls.h".format(prefix)) as fh:
        for cls in (_ for _ in nodeList if not _.isRoot):
            fh.write("virtual void visit(" + prefix + cls.name + "*);\n")


def write_visitor_defns(prefix, nodeList, visitor):
    """Write default `visit()` definitions for the visitor class `visitor`.

    Each non-root class's visit() forwards to the visit() overload of its
    direct superclass through a static_cast.
    """
    argName = "vtxp" if prefix != "Ast" else "nodep"
    template = "void {c}::visit({p}{n}* {v}) {{ visit(static_cast<{p}{b}*>({v})); }}\n"
    with open_file("V3{p}__gen_visitor_defns.h".format(p=prefix)) as fh:
        for cls in nodeList:
            parent = cls.superClass
            if parent is None:
                continue
            fh.write(
                template.format(c=visitor,
                                p=prefix,
                                n=cls.name,
                                b=parent.name,
                                v=argName))


def write_type_enum(prefix, nodeList):
    """Write the type-enumeration fragment for one class hierarchy.

    The fragment has three parts:
      * enum 'en': at<Name> = typeId for each leaf class, ordered by typeId,
        then _ENUM_END one past the root's typeIdMax.
      * enum 'bounds': first<Name>/last<Name> for every non-leaf class,
        ordered by typeIdMin.
      * ascii(): maps an 'en' value to the upper-cased leaf class name.
    """
    root = next(_ for _ in nodeList if _.isRoot)
    with open_file("V3{p}__gen_type_enum.h".format(p=prefix)) as fh:
        emit = fh.write

        emit("    enum en : uint16_t {\n")
        leaves = sorted((c for c in nodeList if c.isLeaf),
                        key=lambda c: c.typeId)
        for leaf in leaves:
            emit("        at{t} = {n},\n".format(t=leaf.name, n=leaf.typeId))
        emit("        _ENUM_END = {n}\n".format(n=root.typeIdMax + 1))
        emit("    };\n")

        emit("    enum bounds : uint16_t {\n")
        inner = sorted((c for c in nodeList if not c.isLeaf),
                       key=lambda c: c.typeIdMin)
        for cls in inner:
            emit("        first{t} = {n},\n".format(t=cls.name,
                                                    n=cls.typeIdMin))
            emit("        last{t}  = {n},\n".format(t=cls.name,
                                                    n=cls.typeIdMax))
        emit("        _BOUNDS_END\n")
        emit("    };\n")

        emit("    const char* ascii() const VL_MT_SAFE {\n")
        emit("        static const char* const names[_ENUM_END + 1] = {\n")
        for leaf in leaves:
            emit('            "{T}",\n'.format(T=leaf.name.upper()))
        emit("            \"_ENUM_END\"\n")
        emit("        };\n")
        emit("        return names[m_e];\n")
        emit("    }\n")


def write_type_tests(prefix, nodeList):
    """Write privateTypeTest<T>() specializations for every class in nodeList.

    The root class always matches. A non-leaf class matches when the dynamic
    type lies in [first<Name>, last<Name>]. A leaf class matches only its
    own at<Name> enumerator.

    Args:
        prefix: "Ast" or "Dfg", selecting the class hierarchy.
        nodeList: the classes of that hierarchy.

    Raises:
        ValueError: if `prefix` names no known hierarchy. This is checked
            before the output file is opened, so no partial file is written.
    """
    # prefix -> (base class, argument name, type enumeration)
    hierarchies = {
        "Ast": ("AstNode", "nodep", "VNType"),
        "Dfg": ("DfgVertex", "vtxp", "VDfgType"),
    }
    if prefix not in hierarchies:
        raise ValueError(
            "write_type_tests: unknown prefix {!r}".format(prefix))
    base, variable, enum = hierarchies[prefix]

    with open_file("V3{p}__gen_type_tests.h".format(p=prefix)) as fh:
        fh.write("// For internal use. They assume argument is not nullptr.\n")
        for node in nodeList:
            fh.write(
                "template<> inline bool {b}::privateTypeTest<{p}{n}>(const {b}* {v}) {{ "
                .format(b=base, p=prefix, n=node.name, v=variable))
            if node.isRoot:
                fh.write("return true;")
            elif not node.isLeaf:
                # Abstract class: range check over its concrete subtypes
                fh.write(
                    "return static_cast<int>({v}->type()) >= static_cast<int>({e}::first{t}) && static_cast<int>({v}->type()) <= static_cast<int>({e}::last{t});"
                    .format(v=variable, e=enum, t=node.name))
            else:
                fh.write("return {v}->type() == {e}::at{t};".format(
                    v=variable, e=enum, t=node.name))
            fh.write(" }\n")


################################################################################
# Ast code generation
################################################################################


def write_ast_type_info(filename):
    """Write one VNTypeInfo initializer row per leaf AstNode class.

    Rows are ordered by typeId. Each has the form
    { "Ast<Name>", {<4 operand kinds>}, {<4 operand names>}, sizeof(Ast<Name>) }.
    An unused operand slot is reported as OP_UNUSED and named "op<n>p".

    Raises:
        ValueError: if an operand has a monad other than "" (plain),
            "Optional" or "List". Such an operand has no operand kind, and
            emitting the row anyway would leave the kind and name arrays
            misaligned.
    """
    # Operand monad -> VNTypeInfo operand kind (empty monad means required)
    opTypeOfMonad = {
        "": "OP_USED",
        "Optional": "OP_OPTIONAL",
        "List": "OP_LIST",
    }
    with open_file(filename) as fh:
        for node in sorted(filter(lambda _: _.isLeaf, AstNodeList),
                           key=lambda _: _.typeId):
            opTypeList = []
            opNameList = []
            for n in range(1, 5):
                op = node.getOp(n)
                if not op:
                    opTypeList.append('OP_UNUSED')
                    opNameList.append('op{0}p'.format(n))
                    continue
                name, monad, _ = op
                opType = opTypeOfMonad.get(monad or "")
                if opType is None:
                    raise ValueError(
                        "Ast{t} operand {n} ('{o}') has unsupported monad {m!r}"
                        .format(t=node.name, n=n, o=name, m=monad))
                opTypeList.append(opType)
                opNameList.append(name)
            opTypeStr = ', '.join(
                ['VNTypeInfo::{0}'.format(s) for s in opTypeList])
            opNameStr = ', '.join(['"{0}"'.format(s) for s in opNameList])
            fh.write(
                '    {{ "Ast{name}", {{{opTypeStr}}}, {{{opNameStr}}}, sizeof(Ast{name}) }},\n'
                .format(
                    name=node.name,
                    opTypeStr=opTypeStr,
                    opNameStr=opNameStr,
                ))


def write_ast_impl(filename):
    """Write generated brokenGen() and cloneRelinkGen() definitions.

    This is done for every AstNode class except the root 'Node'.

    brokenGen() first chains to the superclass's brokenGen(), unless the
    superclass is Node. It then checks each pointer member in node.ptrs:
      * an 'Optional' pointer must reference an existing node when set;
      * any other pointer must be non-null and must exist.
    Finally it returns Ast<Name>::broken().

    cloneRelinkGen() chains to the superclass in the same way. It then
    redirects each pointer member to its clone when one exists.

    Every emitted line is indented by exactly four spaces.
    """
    with open_file(filename) as fh:

        def put(text, **fmt):
            # Emit one line: four-space indent, {}-substitution, newline
            fh.write("    " + text.format(**fmt) + "\n")

        for node in AstNodeList:
            if node.name == "Node":
                continue
            cls = node.name
            parent = node.superClass.name
            chainsToParent = parent != 'Node'

            put("const char* Ast{c}::brokenGen() const {{", c=cls)
            if chainsToParent:
                put("BROKEN_BASE_RTN(Ast{b}::brokenGen());", b=parent)
            for ptr in node.ptrs:
                member = ptr['name']
                if ptr['monad'] == 'Optional':
                    put("BROKEN_RTN(m_{m} && !m_{m}->brokeExists());",
                        m=member)
                else:
                    put("BROKEN_RTN(!m_{m});", m=member)
                    put("BROKEN_RTN(!m_{m}->brokeExists());", m=member)
            # Class-specific rules live in a hand-written broken()
            put("return Ast{c}::broken(); }}", c=cls)

            put("void Ast{c}::cloneRelinkGen() {{", c=cls)
            if chainsToParent:
                put("Ast{b}::cloneRelinkGen();", b=parent)
            for ptr in node.ptrs:
                put("if (m_{m} && m_{m}->clonep()) m_{m} = m_{m}->clonep();",
                    m=ptr['name'])
            put("}}")


def write_ast_macros(filename):
    """Write the ASTGEN_MEMBERS_Ast<Name> and ASTGEN_SUPER_<Name> macros.

    For every AstNode class this emits a multi-line #define. Each line ends
    with a backslash continuation. The macro contains:
      * private pointer members from node.ptrs, initialized to nullptr;
      * typed wrappers around AstNode::unlinkFrBack, unlinkFrBackWithNext,
        cloneTree, cloneTreePure, clonep and addNext;
      * declarations of brokenGen() and cloneRelinkGen();
      * for leaf classes, accept() and clone();
      * typed accessors for operand slots op1..op4: a getter plus a setter,
        or add<Name>() for List operands;
      * private 'using' declarations that hide the superclass's accessors
        for any operand slot this class also defines.

    The macro ends with static_assert(true, "") so that the semicolon
    written after the macro at the use site is consumed. Leaf classes
    additionally get an ASTGEN_SUPER_<Name>(...) macro that calls the
    superclass constructor with VNType::at<Name>.
    """
    with open_file(filename) as fh:

        def emitBlock(pattern, **fmt):
            # Dedent the pattern, indent each line by four spaces, substitute
            # fmt, and turn every newline into a macro line continuation.
            # Text that does not end in a newline gets no continuation.
            fh.write(
                textwrap.indent(textwrap.dedent(pattern),
                                "    ").format(**fmt).replace("\n", " \\\n"))

        for node in AstNodeList:
            fh.write("#define ASTGEN_MEMBERS_Ast{t} \\\n".format(t=node.name))
            any_ptr = False
            for ptr in node.ptrs:
                if not any_ptr:
                    fh.write("private: \\\n")
                    any_ptr = True
                # Trailing newline so each member gets its own continued macro
                # line instead of running into the next member / 'public:'
                emitBlock("Ast{kind}* m_{name} = nullptr;\n",
                          name=ptr['name'],
                          kind=ptr['kind'])
            if any_ptr:
                fh.write("public: \\\n")
            # TODO pointer accessors
            # for ptr in node.ptrs:
            #     emitBlock(
            #         ("{kind}* {name}() const {{ return m_{name}; }}\n" +
            #          "void {name}({kind}* nodep) {{ m_{name} = nodep; }}"),
            #         name=ptr['name'],
            #         kind=ptr['kind'])

            emitBlock('''\
            Ast{t}* unlinkFrBack(VNRelinker* linkerp = nullptr) {{
                return static_cast<Ast{t}*>(AstNode::unlinkFrBack(linkerp));
            }}
            Ast{t}* unlinkFrBackWithNext(VNRelinker* linkerp = nullptr) {{
                return static_cast<Ast{t}*>(AstNode::unlinkFrBackWithNext(linkerp));
            }}
            Ast{t}* cloneTree(bool cloneNext) {{
                return static_cast<Ast{t}*>(AstNode::cloneTree(cloneNext));
            }}
            Ast{t}* cloneTreePure(bool cloneNext) {{
                return static_cast<Ast{t}*>(AstNode::cloneTreePure(cloneNext));
            }}
            Ast{t}* clonep() const {{ return static_cast<Ast{t}*>(AstNode::clonep()); }}
            Ast{t}* addNext(Ast{t}* nodep) {{ return static_cast<Ast{t}*>(AstNode::addNext(this, nodep)); }}
            const char* brokenGen() const override;
            void cloneRelinkGen() override;
            ''',
                      t=node.name)

            if node.isLeaf:
                emitBlock('''\
                void accept(VNVisitorConst& v) override {{ v.visit(this); }}
                AstNode* clone() override {{ return new Ast{t}(*this); }}
                ''',
                          t=node.name)

            # Superclass accessor names shadowed by this class's operands
            hiddenMethods = []

            for n in range(1, 5):
                op = node.getOp(n)
                if not op:
                    continue
                name, monad, kind = op
                # Generic 'Node' operands are returned untyped; others are
                # cast with VN_DBG_AS
                retrieve = ("VN_DBG_AS(op{n}p(), {kind})" if kind != "Node"
                            else "op{n}p()").format(n=n, kind=kind)
                superOp = node.superClass.getOp(n)
                superName = None
                if superOp:
                    superName = superOp[0]
                    hiddenMethods.append(superName)
                if monad == "List":
                    emitBlock('''\
                    Ast{kind}* {name}() const VL_MT_STABLE {{ return {retrieve}; }}
                    void add{Name}(Ast{kind}* nodep) {{ addNOp{n}p(reinterpret_cast<AstNode*>(nodep)); }}
                    ''',
                              kind=kind,
                              name=name,
                              Name=name[0].upper() + name[1:],
                              n=n,
                              retrieve=retrieve)
                    if superOp:
                        hiddenMethods.append("add" + superName[0].upper() +
                                             superName[1:])
                elif monad == "Optional":
                    emitBlock('''\
                    Ast{kind}* {name}() const VL_MT_STABLE {{ return {retrieve}; }}
                    void {name}(Ast{kind}* nodep) {{ setNOp{n}p(reinterpret_cast<AstNode*>(nodep)); }}
                    ''',
                              kind=kind,
                              name=name,
                              n=n,
                              retrieve=retrieve)
                else:
                    # Required operand: setter uses setOp<n>p
                    emitBlock('''\
                    Ast{kind}* {name}() const VL_MT_STABLE {{ return {retrieve}; }}
                    void {name}(Ast{kind}* nodep) {{ setOp{n}p(reinterpret_cast<AstNode*>(nodep)); }}
                    ''',
                              kind=kind,
                              name=name,
                              n=n,
                              retrieve=retrieve)

            if hiddenMethods:
                fh.write("private: \\\n")
                for method in hiddenMethods:
                    fh.write("    using Ast{sup}::{method}; \\\n".format(
                        sup=node.superClass.name, method=method))
                fh.write("public: \\\n")

            fh.write(
                "    static_assert(true, \"\")\n")  # Swallowing the semicolon

            # Only care about leaf classes for the rest
            if node.isLeaf:
                fh.write(
                    "#define ASTGEN_SUPER_{t}(...) Ast{b}(VNType::at{t}, __VA_ARGS__)\n"
                    .format(t=node.name, b=node.superClass.name))
            fh.write("\n")


def write_ast_yystype(filename):
    """Write one `Ast<Name>* <name>p;` member per AstNode class.

    The member name is the class name with its first character lower-cased
    and a 'p' suffix (e.g. NodeExpr -> nodeExprp).
    """
    with open_file(filename) as fh:
        for node in AstNodeList:
            cls = node.name
            member = cls[0].lower() + cls[1:]
            fh.write("Ast%s* %sp;\n" % (cls, member))


################################################################################
# DFG code generation
################################################################################


def write_dfg_macros(filename):
    """Write one ASTGEN_MEMBERS_Dfg<Name> macro per DfgVertex class.

    Leaf classes get a static dfgType() and an accept() override. Every
    class gets, per operand, a getter and setter backed by source<i>() and
    relinkSource<i>(). A class with at least one operand also gets a
    srcName() override that returns the operand's name.
    """
    with open_file(filename) as fh:

        def emitMacroLines(pattern, **fmt):
            # Dedent, indent by four spaces, substitute, and end every line
            # with a macro continuation backslash
            body = textwrap.indent(textwrap.dedent(pattern), "    ")
            fh.write(body.format(**fmt).replace("\n", " \\\n"))

        for vertex in DfgVertexList:
            fh.write("#define ASTGEN_MEMBERS_Dfg{t} \\\n".format(t=vertex.name))

            if vertex.isLeaf:
                emitMacroLines('''\
                    static constexpr VDfgType dfgType() {{ return VDfgType::at{t}; }};
                    void accept(DfgVisitor& v) override {{ v.visit(this); }}
                    ''',
                               t=vertex.name)

            operandNames = []
            for slot in range(1, vertex.arity + 1):
                opName, _, _ = vertex.getOp(slot)
                operandNames.append(opName)

            # Source indices are zero based
            for idx, opName in enumerate(operandNames):
                emitMacroLines('''\
                    DfgVertex* {name}() const {{ return source<{n}>(); }}
                    void {name}(DfgVertex* vtxp) {{ relinkSource<{n}>(vtxp); }}
                    ''',
                               name=opName,
                               n=idx)

            if operandNames:
                quoted = ", ".join('"' + s + '"' for s in operandNames)
                emitMacroLines('''\
                    const std::string srcName(size_t idx) const override {{
                        static const char* names[{a}] = {{ {ns} }};
                        return names[idx];
                    }}
                    ''',
                               a=vertex.arity,
                               ns=quoted)
            # Consumes the semicolon written after the macro at its use site
            fh.write("    static_assert(true, \"\")\n")


def write_dfg_auto_classes(filename):
    """Write class definitions for the automatically derived DfgVertex types.

    Only leaf vertices with no defining source file are emitted. Each class
    derives from its superclass vertex type and expands ASTGEN_MEMBERS_Dfg<Name>.
    """
    classTemplate = textwrap.dedent('''\
        class Dfg{t} final : public Dfg{s} {{
        public:
            Dfg{t}(DfgGraph& dfg, FileLine* flp, AstNodeDType* dtypep)
                : Dfg{s}{{dfg, dfgType(), flp, dtypep}} {{}}
            ASTGEN_MEMBERS_Dfg{t};
        }};
        ''')
    with open_file(filename) as fh:
        for vertex in DfgVertexList:
            if vertex.file is None and vertex.isLeaf:
                fh.write(
                    classTemplate.format(t=vertex.name,
                                         s=vertex.superClass.name))
        fh.write("\n")


def write_dfg_ast_to_dfg(filename):
    """Write AstToDfg visit() overrides for the automatically derived vertices.

    For each leaf vertex with no defining source file, the generated visitor:
      1. iterates the AstNode's operands, returning early once
         m_foundUnhandled is set;
      2. creates the Dfg vertex, or, if creation fails, marks the node
         unhandled and increments m_ctx.m_nonRepNode;
      3. links each operand's vertex (stored in user1) as a source;
      4. records the new vertex in m_uncommittedVertices and nodep->user1p().
    """
    with open_file(filename) as fh:
        for vertex in DfgVertexList:
            if vertex.file is not None or not vertex.isLeaf:
                continue
            cls = vertex.name
            out = [
                "void visit(Ast{t}* nodep) override {{\n".format(t=cls),
                '    UASSERT_OBJ(!nodep->user1p(), nodep, "Already has Dfg vertex");\n\n',
                "    if (unhandled(nodep)) return;\n\n",
            ]
            for j in range(1, vertex.arity + 1):
                out.append("    iterate(nodep->op{j}p());\n".format(j=j))
                out.append("    if (m_foundUnhandled) return;\n")
                out.append(
                    '    UASSERT_OBJ(nodep->op{j}p()->user1p(), nodep, "Child {j} missing Dfg vertex");\n'
                    .format(j=j))
            out.append("\n")
            out.append(
                "    Dfg{t}* const vtxp = makeVertex<Dfg{t}>(nodep, *m_dfgp);\n"
                .format(t=cls))
            out.append("    if (!vtxp) {\n"
                       "        m_foundUnhandled = true;\n"
                       "        ++m_ctx.m_nonRepNode;\n"
                       "        return;\n"
                       "    }\n\n")
            for i in range(vertex.arity):
                out.append(
                    "    vtxp->relinkSource<{i}>(nodep->op{j}p()->user1u().to<DfgVertex*>());\n"
                    .format(i=i, j=i + 1))
            out.append("\n")
            out.append("    m_uncommittedVertices.push_back(vtxp);\n")
            out.append("    nodep->user1p(vtxp);\n")
            out.append("}\n")
            fh.write("".join(out))


def write_dfg_dfg_to_ast(filename):
    """Write DfgToAst visit() overrides for the automatically derived vertices.

    Each override converts every source vertex with convertSource() and
    builds the matching AstNode via makeNode<Ast<Name>>(vtxp, op1p, ...).
    """
    with open_file(filename) as fh:
        for vertex in DfgVertexList:
            if not (vertex.file is None and vertex.isLeaf):
                continue
            opVars = ["op{}p".format(i + 1) for i in range(vertex.arity)]
            fh.write(
                "void visit(Dfg{t}* vtxp) override {{\n".format(t=vertex.name))
            for i, var in enumerate(opVars):
                fh.write(
                    "    AstNodeExpr* const {v} = convertSource(vtxp->source<{i}>());\n"
                    .format(v=var, i=i))
            fh.write(
                "    m_resultp = makeNode<Ast{t}>(vtxp".format(t=vertex.name))
            fh.write("".join(", " + var for var in opVars))
            fh.write(");\n")
            fh.write("}\n")


######################################################################
# main

parser = argparse.ArgumentParser(
    allow_abbrev=False,
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description="""Generate V3Ast headers to reduce C++ code duplication.""",
    epilog=
    """Copyright 2002-2024 by Wilson Snyder. This program is free software; you
can redistribute it and/or modify it under the terms of either the GNU
Lesser General Public License Version 3 or the Perl Artistic License
Version 2.0.

SPDX-License-Identifier: LGPL-3.0-only OR Artistic-2.0""")

# -I, --astdef and --dfgdef are required. The code below unconditionally joins
# definition file names onto -I and iterates both definition lists, so
# omitting any of them would otherwise fail later with an opaque TypeError
# instead of a usage error.
parser.add_argument('-I',
                    action='store',
                    required=True,
                    help='source code include directory')
parser.add_argument('--astdef',
                    action='append',
                    required=True,
                    help='add AST definition file (relative to -I)')
parser.add_argument('--dfgdef',
                    action='append',
                    required=True,
                    help='add DFG definition file (relative to -I)')
parser.add_argument('--classes',
                    action='store_true',
                    help='makes class declaration files')
parser.add_argument('--debug', action='store_true', help='enable debug')

parser.add_argument('infiles', nargs='*', help='list of input .cpp filenames')

Args = parser.parse_args()

###############################################################################
# Read AstNode definitions
###############################################################################

# The root AstNode type is created here directly; no definition file
# declares it.
AstNodes["Node"] = Node("Node", None)

# Parse each --astdef file (given relative to -I) into AstNodes
for filename in Args.astdef:
    read_types(os.path.join(Args.I, filename), AstNodes, "Ast")

# With all classes known, compute the derived hierarchy properties
AstNodes["Node"].complete()

# All AstNode classes, ordered by name
AstNodeList = tuple(AstNodes[key] for key in sorted(AstNodes))

check_types(AstNodeList, "Ast", "Node")

###############################################################################
# Read and generate DfgVertex definitions
###############################################################################

# Set up the root DfgVertex type and the hand-written arity base types.
# These are standalone, so the sources need not be parsed for them.
DfgVertices["Vertex"] = Node("Vertex", None)
for vertexBaseName in ("VertexUnary", "VertexBinary", "VertexTernary",
                       "VertexVariadic"):
    DfgVertices[vertexBaseName] = Node(vertexBaseName, DfgVertices["Vertex"])
    DfgVertices["Vertex"].addSubClass(DfgVertices[vertexBaseName])

# Read DfgVertex definitions
for filename in Args.dfgdef:
    read_types(os.path.join(Args.I, filename), DfgVertices, "Dfg")

# Add DfgVertex sub-types derived automatically from AstNode sub-types. A leaf
# AstNode under NodeUniop/NodeBiop/NodeTriop without an explicitly defined
# vertex gets a vertex under the matching arity base type.
for node in AstNodeList:
    # Ignore the hierarchy for now
    if not node.isLeaf:
        continue

    # Ignore any explicitly defined vertex
    if node.name in DfgVertices:
        continue

    if node.isSubClassOf(AstNodes["NodeUniop"]):
        base = DfgVertices["VertexUnary"]
    elif node.isSubClassOf(AstNodes["NodeBiop"]):
        base = DfgVertices["VertexBinary"]
    elif node.isSubClassOf(AstNodes["NodeTriop"]):
        base = DfgVertices["VertexTernary"]
    else:
        continue

    vertex = Node(node.name, base)
    DfgVertices[node.name] = vertex
    base.addSubClass(vertex)

    for n in range(1, node.arity + 1):
        op = node.getOp(n)
        if op is not None:
            name, monad, kind = op
            # Only operands with an empty monad map onto a vertex source.
            # Checked explicitly rather than with assert, which 'python -O'
            # strips, letting wrong vertex code be generated silently.
            if monad != "":
                sys.exit(
                    "%Error: Cannot represent AstNode as DfgVertex: Ast{t} operand '{o}' has monad '{m}'"
                    .format(t=node.name, o=name, m=monad))
            vertex.addOp(n, name, "", "")

# Compute derived properties over the whole DfgVertex hierarchy
DfgVertices["Vertex"].complete()

DfgVertexList = tuple(map(lambda _: DfgVertices[_],
                          sorted(DfgVertices.keys())))

check_types(DfgVertexList, "Dfg", "Vertex")

###############################################################################
# Read additional files
###############################################################################

# Stage ordering is taken from Verilator.cpp in the -I directory
read_stages(Args.I + "/Verilator.cpp")

# Scan every grammar (.y), header (.h) and C++ (.cpp) file in -I, in that
# order, for class references
source_files = [
    path for suffix in ("y", "h", "cpp")
    for path in glob.glob(Args.I + "/*." + suffix)
]
for filename in source_files:
    read_refs(filename)

###############################################################################
# Generate output
###############################################################################

if Args.classes:
    write_report("V3Ast__gen_report.txt")
    # For each hierarchy: prefix, classes, default visitor class, and the
    # hierarchy-specific (writer, output file) pairs run after the shared ones
    for genPrefix, genNodes, genVisitor, genSpecific in (
        ("Ast", AstNodeList, "VNVisitorConst", (
            (write_ast_type_info, "V3Ast__gen_type_info.h"),
            (write_ast_impl, "V3Ast__gen_impl.h"),
            (write_ast_macros, "V3Ast__gen_macros.h"),
            (write_ast_yystype, "V3Ast__gen_yystype.h"),
        )),
        ("Dfg", DfgVertexList, "DfgVisitor", (
            (write_dfg_macros, "V3Dfg__gen_macros.h"),
            (write_dfg_auto_classes, "V3Dfg__gen_auto_classes.h"),
            (write_dfg_ast_to_dfg, "V3Dfg__gen_ast_to_dfg.h"),
            (write_dfg_dfg_to_ast, "V3Dfg__gen_dfg_to_ast.h"),
        )),
    ):
        write_forward_class_decls(genPrefix, genNodes)
        write_visitor_decls(genPrefix, genNodes)
        write_visitor_defns(genPrefix, genNodes, genVisitor)
        write_type_enum(genPrefix, genNodes)
        write_type_tests(genPrefix, genNodes)
        for genWriter, genFile in genSpecific:
            genWriter(genFile)

# Expand each .cpp named on the command line: read -I/<name>.cpp and write
# <name>__gen.cpp in the current directory.
for cpt in Args.infiles:
    # Require a literal ".cpp" suffix. An unescaped regex '.cpp$' would also
    # accept e.g. "foo_cpp" and then strip "_cpp", opening the wrong file.
    if not cpt.endswith(".cpp"):
        sys.exit("%Error: Expected argument to be .cpp file: " + cpt)
    cpt = cpt[:-len(".cpp")]
    Cpt().process(in_filename=Args.I + "/" + cpt + ".cpp",
                  out_filename=cpt + "__gen.cpp")

######################################################################
# Local Variables:
# compile-command: "touch src/V3AstNodeExpr.h ; v4make"
# End:
