#!/usr/bin/env python3

from sys import stdout
from pycparser import c_ast, c_generator, parse_file

cross_gcc = "/opt/cross-m68k/bin/m68k-none-elf-gcc"

class Visitor(c_ast.NodeVisitor):
    gen = c_generator.CGenerator()

    def __init__(self):
        self.path = None

        self.calls = {}
        self.knows = {}

        self.declares = {}
        self.in_header = {}

        self.decl = {}

        self.max_decl = {}
        self.max_head = {}

        self.fix_head = {}
        self.fix_decl = {}

    def visit_FuncCall(self, node):
        if node.coord.file != self.path:
            return

        if type(node.name) is not c_ast.ID:
            return

        if self.path not in self.calls:
            self.calls[self.path] = set()

        self.calls[self.path].add(node.name.name)
        self.generic_visit(node)

    def visit_path(self, path, ast):
        self.path = path
        self.visit(ast)
        self.path = None

        for node in ast.ext:
            if type(node) is not c_ast.FuncDef:
                continue

            node.decl.storage = []
            decl = self.gen.visit(node.decl)

            if node.decl.name in self.decl and \
               self.decl[node.decl.name] != decl:
                raise Exception("declaration mismatch: {} vs. {}". \
                                format(self.decl[node.decl.name], decl))

            self.decl[node.decl.name] = decl

        for node in ast.ext:
            if type(node) is c_ast.Decl and \
               type(node.type) is c_ast.FuncDecl:
                decl = node

            elif type(node) is c_ast.FuncDef:
                decl = node.decl

            else:
                continue

            if path not in self.knows:
                self.knows[path] = set()

            self.knows[path].add(decl.name)

            if decl.coord.file != path:
                continue

            if type(node) is c_ast.Decl:
                if path not in self.max_decl:
                    self.max_decl[path] = -1;

                if decl.coord.line > self.max_decl[path]:
                    self.max_decl[path] = decl.coord.line

            if path not in self.declares:
                self.declares[path] = set()

            self.declares[path].add(decl.name)

            if path[-2:] == ".h":
                if decl.name in self.in_header:
                    raise Exception("duplicate definition of {}".format(decl.name))

                self.in_header[decl.name] = path

    def fix(self):
        for path, calls in sorted(self.calls.items()):
            knows = self.knows[path]
            needs = calls - knows

            for need in sorted(needs):
                if need in self.in_header:
                    if path not in self.fix_head:
                        self.fix_head[path] = set()

                    self.fix_head[path].add(self.in_header[need])

                else:
                    if path not in self.fix_decl:
                        self.fix_decl[path] = set()

                    self.fix_decl[path].add(self.decl[need])

            no = 1

            with open(path, "r") as f:
                for line in f:
                    if len(line) >= 8 and line[:8] == "#include":
                        self.max_head[path] = no

                    no += 1

        for path, _ in sorted(self.calls.items()):
            if path not in self.fix_head and path not in self.fix_decl:
                continue

            max_head = self.max_head[path] if path in self.max_head else None
            max_decl = self.max_decl[path] if path in self.max_decl else None

            if max_head is None:
                max_head = max_decl

            if max_decl is None:
                max_decl = max_head

            no = 1
            out = []
            extern = "extern "

            with open(path, "r") as f:
                for line in f:
                    out.append(line)

                    if "extern\t" in line:
                        extern = "extern\t"

                    if no == max_head and path in self.fix_head:
                        out.append("\n")

                        for head in sorted(self.fix_head[path]):
                            out.append("#include \"{}\"\n".format(head.split("/")[-1]))

                    if no == max_decl and path in self.fix_decl:
                        out.append("\n")

                        for decl in sorted(self.fix_decl[path]):
                            out.append(extern + decl + ";\n")

                    no += 1

            with open(path, "w") as f:
                for line in out:
                    f.write(line)

vis = Visitor()

with open("misc/c-files.txt", "r") as f:
    for path in f:
        path = path.rstrip()

        if path == "ram/wdfield.c": # breaks pycparser
            continue

        stdout.write("parsing {}                    \r".format(path))
        stdout.flush()

        ast = parse_file(path, use_cpp = True, cpp_path = cross_gcc,
                         cpp_args = ["-E", "-I", "include", "-include", "predef.h"])
        # ast.show()

        vis.visit_path(path, ast)

    print("")

vis.fix()
