#!/usr/bin/env python3

from sys import stdout
from os import unlink
from pickle import load
from pycparser import c_ast, parse_file, c_generator
from re import subn, escape

cross_gcc = "/opt/cross-m68k/bin/m68k-none-elf-gcc"

class Visitor(c_ast.NodeVisitor):
    def reset(self):
        self.path = None
        self.orig = None
        self.align = None

        self.func_def = False
        self.prev_line = 0
        self.this_line = None
        self.subs = []
        self.sub_def = False

    def __init__(self):
        with open("proto.dat", "rb") as f:
            proto = load(f)

        self.proto = {}

        for p in proto:
            self.proto[p.name] = p

        self.gen = c_generator.CGenerator()
        self.reset()

    def visit_Decl(self, node):
        if node.coord.file != self.path:
            return

        sub_def = self.func_def
        self.func_def = False

        if type(node.type) is not c_ast.FuncDecl:
            return

        if node.coord.line <= self.prev_line:
            return

        if self.this_line is None:
            self.this_line = node.coord.line

        if node.coord.line != self.this_line:
            return

        self.sub_def = sub_def

        self.generic_visit(node)
        self.subs.append(node)

    def visit_FuncDef(self, node):
        if node.coord.file != self.path:
            return

        self.func_def = True
        self.generic_visit(node)

    def get_ast(self):
        with open(self.path, "w") as f:
            f.write(self.orig)

        ast = parse_file(self.path, use_cpp = True, cpp_path = cross_gcc,
                         cpp_args = ["-E", "-I", "include", "-include", "predef.h"])
        # ast.show()

        unlink(self.path)
        return ast

    def make_old_new(self, sub):
        old = self.gen.visit(sub).replace("*" + sub.name, "* " + sub.name)

        old_type = sub.type
        sub.type = self.proto[sub.name].type

        new = self.gen.visit(sub)
        sub.type = old_type

        print("{} -> {}".format(old, new))
        return (old, new)

    def replace_dec(self):
        print("--- {}:{}:".format(self.path, self.this_line))

        lines = self.orig.split("\n")
        out = lines[:self.this_line - 1]

        for sub in self.subs:
            (_, new) = self.make_old_new(sub)

            if self.align:
                news_spa = new.split(" ")
                news_tab = []

                while news_spa[0] in ["extern", "char", "short", "int",
                                      "unsigned", "long", "struct", "void",
                                      "FILE", "PFS"]:
                    is_struct = news_spa[0] == "struct"

                    news_tab.append(news_spa[0])
                    news_spa = news_spa[1:]

                    if is_struct:
                        news_tab.append(news_spa[0])
                        news_spa = news_spa[1:]

                new_spa = " ".join(news_spa)
                new_tab = "\t".join(news_tab)
                new = new_tab + "\t" + new_spa

            out.append(new + ";")

        out += lines[self.this_line:]
        self.orig = "\n".join(out)

        self.this_line += len(self.subs) - 1

    def replace_def(self):
        print("--- {}:{}:".format(self.path, self.this_line))

        def make_re(olds, strip = None):
            escs = [escape(x) for x in olds if x != strip]
            re = "[\t\n ]*".join(escs) + "[^{/]*\n{"
            # print("re {}".format(re))
            return re

        for sub in self.subs:
            (old, new) = self.make_old_new(sub)

            olds = old.split(" ")
            new += "\n{"

            (new_orig, n) = subn(make_re(olds), new, self.orig)

            if n == 0 and "int" in olds:
                (new_orig, n) = subn(make_re(olds, "int"), new, self.orig)

            if n != 1:
                raise Exception("error while replacing definition")

            self.orig = new_orig

    def fix(self, path, orig):
        self.path = path
        self.orig = orig
        self.align = "extern\t" in orig or \
                "char\t" in orig or \
                "short\t" in orig or \
                "int\t" in orig or \
                "long\t" in orig

        while True:
            self.visit(self.get_ast())

            if self.this_line is None:
                break

            if self.sub_def:
                self.replace_def()
            else:
                self.replace_dec()

            self.prev_line = self.this_line
            self.this_line = None
            self.subs = []
            self.sub_def = False

        fixed = self.orig
        self.reset()
        return fixed

vis = Visitor()

with open("misc/c-files.txt", "r") as f:
    for path in f:
        path = path.rstrip()

        if path == "ram/wdfield.c": # breaks pycparser
            continue

        stdout.write("fixing {}                    \r".format(path))
        stdout.flush()

        with open(path, "r") as f:
            orig = f.read()

        segs = path.split("/")
        path_ = "/".join(segs[:-1] + ["_" + segs[-1]])

        fixed = vis.fix(path_, orig)

        with open(path, "w") as f:
            f.write(fixed)

    print("")
