#!/usr/bin/python

import inspect
import os
import pdb
import copy
import sys
import json
import re
import jinja2
from optparse import OptionParser

# Django set up
#
# The order of the statements below matters: the import paths and the
# DJANGO_SETTINGS_MODULE default are put in place first, django.setup() is
# called next, and only afterwards is core.models imported.

import django
# Make the current directory and /opt/xos importable (presumably so that
# "xos.settings" and the "core" package can be found -- confirm for your layout).
sys.path.append('.')
sys.path.append('/opt/xos')
# Only a default: a DJANGO_SETTINGS_MODULE already present in the environment wins.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "xos.settings")
from django.db.models.fields.related import ForeignKey, ManyToManyField
from django.conf import settings
from django.contrib.contenttypes.models import ContentType

django.setup()

from core.models import PlCoreBase

# Parsed command-line options; assigned by main() and read by enum_classes()
# (options.blacklist) and Generator.rest_models() (options.norest).
options = None

def is_model_class(model):
    """ Tell whether 'model' is a class we want to generate code for.

        True is returned for any class whose MRO (the class itself included)
        contains a class *named* "PlCoreBase" or "PlModelMixIn"; the match is
        by name only. PlModelMixIn itself, non-class objects and everything
        else yield False.
    """
    if not inspect.isclass(model) or model.__name__ == "PlModelMixIn":
        return False

    interesting = ("PlCoreBase", "PlModelMixIn")
    return any(klass.__name__ in interesting for klass in inspect.getmro(model))

def module_has_models(module):
    """ Report whether any attribute of 'module' passes is_model_class().

        Attributes are visited in dir() order and the scan stops at the
        first match.
    """
    return any(is_model_class(getattr(module, attr)) for attr in dir(module))

def app_get_models_module(app):
    """ Locate the module holding the XOS models of 'app'.

        "<app>.models" is imported and the dotted path is then walked from
        the top-level package downwards; the first module on that path for
        which module_has_models() is true is returned.

        Returns False when "<app>.models" cannot be imported and None when
        nothing on the path contains models, so callers should only rely on
        the truthiness of the result.
    """
    dotted = app + ".models"
    try:
        # __import__ hands back the top-level package, not the leaf module.
        current = __import__(dotted)
    except ImportError:
        return False

    remaining = dotted.split(".")[1:]
    while True:
        if module_has_models(current):
            return current
        if not remaining:
            return None
        current = getattr(current, remaining.pop(0))


def singular(foo, keys):
    """ Map the plural 'foo' back to the entry of 'keys' it was formed from.

        The first key k (in iteration order) for which 'foo' equals k + 'es'
        or k + 's' is returned. An Exception is raised when no key matches.
    """
    for candidate in keys:
        if foo in (candidate + 'es', candidate + 's'):
            return candidate
    raise Exception('Plural to singular error for %s' % foo)

# Handle on this module's global namespace. It is not referenced by any of the
# code visible in this file -- NOTE(review): possibly a leftover; confirm that
# nothing external relies on it before removing.
g = globals()

def enum_classes(apps):
    """ Collect the model classes defined by each application in 'apps'.

        For every app, the module returned by app_get_models_module() is
        scanned with dir(); each attribute accepted by is_model_class() whose
        name is not in options.blacklist is recorded.

        Side effects:
          * the module globals 'app_map' and 'class_map' are reset and filled:
              app_map   -- {class name: app the class was found under}
              class_map -- {app: {source file name: [classes]}}
          * every collected class gets a 'class_name' attribute set to its
            __name__.

        class_map has the same nesting as Generator.apps. It was previously
        built as a list of one-entry dicts, which made the "file already
        seen" test always fail, so classes from one file were never grouped.

        Returns the list of collected classes, in discovery order. The same
        class is listed once per name it is reachable under in dir().
    """
    global app_map
    global class_map
    app_map = {}
    class_map = {}
    model_classes = []

    for app in apps:
        models_module = app_get_models_module(app)
        if not models_module:
            # No importable models module, or none containing models:
            # there is nothing to enumerate for this app.
            continue

        for classname in dir(models_module):
            c = getattr(models_module, classname, None)

            # For services, prevent loading of core models as it causes
            # duplication.
            if hasattr(c, "_meta") and hasattr(c._meta, "app_label"):
                if c._meta.app_label == "core" and app != "core":
                    continue

            if not is_model_class(c) or c.__name__ in options.blacklist:
                continue

            model_classes.append(c)
            app_map[c.__name__] = app
            c.class_name = c.__name__

            # Last component of the defining module, i.e. the source file name.
            file_name = c.__module__.rsplit('.', 1)[-1]
            class_map.setdefault(app, {}).setdefault(file_name, []).append(c)

    return model_classes

class GenObj(object):
    """ Template-facing wrapper around a single model class.

        Attributes set here:
          model           -- the wrapped model class
          props           -- field names, filled by Generator.compute_links()
          fields          -- field objects not inherited from PlCoreBase,
                             filled by Generator.compute_links()
          all_fields      -- every field object processed by
                             Generator.compute_links()
          field_dict      -- initialised to an empty list; not used by the
                             code in this file
          refs            -- GenObj entries appended by
                             Generator.compute_links()
          reverse_refs    -- GenObj copies for reverse relations, appended by
                             Generator.compute_links()
          plural_name     -- optional override returned by plural()
          content_type_id -- id of the ContentType registered for the model

        Generator additionally attaches 'app', 'app_name' and 'class_name'
        (add_object) and 'multi' / 'related_name' on copies (compute_links).
    """

    # These are things that either for historic reasons or due to incorrect
    # naming, got called something different than the autogen thinks they
    # should be called. Maps plural() output to the REST name actually used.
    _REST_FIXUP = {
        'controllernetworkses': 'controllernetworks',
        'controllerimageses': 'controllerimages',
        'controllersliceses': 'controllerslices',
        'controlleruserses': 'controllerusers',
        'sitedeploymentses': 'sitedeployments',
        'siteroles': 'site_roles',
        'sliceprivileges': 'slice_privileges',
        'sliceroles': 'slice_roles',
    }

    def __str__(self):
        """ Lower-cased name of the wrapped model class. """
        return str(self.model.__name__.lower())

    def __init__(self, m):
        """ Wrap model class 'm'; looks up its ContentType via the Django ORM. """
        self.model = m
        self.props = []
        self.fields = []
        self.all_fields = []
        self.field_dict = []
        self.refs = []
        self.reverse_refs = []
        self.plural_name = None
        self.content_type_id = ContentType.objects.get_for_model(m).id

    def plural(self):
        """ Return plural_name when set, else str(self) + 'es' / 's'.

            'es' is used when the name already ends in 's'.
        """
        if self.plural_name:
            return self.plural_name

        name = str(self)
        return name + ('es' if name.endswith('s') else 's')

    def singular(self):
        """ Same as str(self). """
        return str(self)

    def rest_name(self):
        """ plural(), corrected through the _REST_FIXUP table where listed. """
        plural = self.plural()
        return self._REST_FIXUP.get(plural, plural)

    def camel(self):
        """ The model class name with its original capitalisation. """
        return str(self.model.__name__)
		
class Generator(dict):
    """ Registry of GenObj wrappers, keyed by lower-cased model name.

        'apps' additionally indexes the same wrappers as
        {app: {source file name: [GenObj, ...]}}.
    """

    def __init__(self):
        self.apps = {}

    def all(self):
        """ Every registered GenObj. """
        return self.values()

    def rest_models(self):
        """ Registered GenObjs whose name is not listed in options.norest. """
        norest = [x.lower() for x in options.norest]
        return [v for v in self.values() if str(v) not in norest]

    def regex(self, r):
        """ List of registered GenObjs whose name matches regex 'r' (re.match). """
        return [o for o in self.values() if re.match(r, str(o))]

    def add_object(self, o):
        """ Wrap model class 'o' in a GenObj and register it.

            The app is looked up in the module-level 'app_map' filled by
            enum_classes(). A model missing from that map is reported on
            stdout and skipped: it cannot be filed under an app, and going on
            used to fail with an AttributeError on the unset 'app' attribute.
            'o' must carry the 'class_name' attribute enum_classes() sets.
        """
        obj = GenObj(o)
        try:
            obj.app = app_map[o.__name__]  # full name
        except KeyError:
            print("KeyError: %r" % o.__name__)
            return

        if hasattr(o, "_meta") and hasattr(o._meta, "app_label"):
            obj.app_name = o._meta.app_label
        else:
            obj.app_name = obj.app.split(".")[-1]  # only the last part
        obj.class_name = o.class_name

        # Last component of the defining module, i.e. the source file name.
        file_name = o.__module__.rsplit('.', 1)[-1]
        self.apps.setdefault(obj.app, {}).setdefault(file_name, []).append(obj)

        self[str(obj).lower()] = obj

    def compute_links(self):
        """ Fill props/fields/all_fields/refs/reverse_refs of every GenObj.

            First pass, per object: plain fields, then many-to-many fields.
            Second pass, per object: reverse references, done after the first
            pass has finished for all objects.
        """
        base_props = [f.name for f in PlCoreBase._meta.fields]

        for obj in self.values():
            self._collect_fields(obj, base_props)
            self._collect_many_to_many(obj)

        for obj in self.values():
            self._collect_reverse_refs(obj)

    def _collect_fields(self, obj, base_props):
        """ Sort the concrete fields of obj.model into obj's field lists. """
        for f in obj.model._meta.fields:
            to_name = str(f.rel.to) if (f and f.rel) else None

            if type(f) == ForeignKey and to_name and to_name in self:
                # NOTE(review): to_name is str() of the related model class,
                # while the registry keys are lower-cased model names, so this
                # branch does not appear to fire in practice. It also reads
                # f.to_name, which nothing in this file sets. Foreign keys
                # currently fall through to the generic branch below; the
                # branch is kept verbatim so that behaviour stays the same.
                refobj = self[f.to_name]

                if str(obj) == 'slice' and f.to_name == 'networks':
                    obj.refs.append(refobj)
                related_name = f.related_query_name()
                if related_name != '+' and related_name.lower() != str(obj).lower():
                    cobj = copy.deepcopy(obj)
                    cobj.multi = True
                    cobj.plural_name = related_name
                    refobj.refs.append(cobj)
            elif f.name.endswith("_ptr"):
                # django inherited model, for example HPCService
                # cause swagger and REST to break
                continue
            else:
                f.type = f.__class__.__name__
                # Exact type test on purpose: ForeignKey subclasses are not
                # treated as foreign keys here.
                if type(f) == ForeignKey:
                    f.related.model.class_name = f.related.model.__name__
                if f.name not in base_props:
                    obj.fields.append(f)
                obj.all_fields.append(f)
                obj.props.append(f.name)

    def _collect_many_to_many(self, obj):
        """ Cross-link obj with the registered targets of its m2m fields. """
        for f in obj.model._meta.many_to_many:
            try:
                related_model_name = f.m2m_reverse_field_name()
            except Exception:
                # Fall back to the trailing component of the m2m table name.
                related_model_name = f.m2m_db_table().rsplit('_', 1)[-1]

            related_name = f.related_query_name()
            if related_model_name not in self:
                continue

            # The target always gets a reference back to obj ...
            refobj = self[related_model_name]
            back = copy.deepcopy(obj)
            back.multi = True
            refobj.refs.append(back)

            # deal with upgradeFrom_rel_+
            # ... but obj only gets the forward reference when the relation
            # name does not end in '+'.
            if related_name.endswith("+"):
                continue

            forward = copy.deepcopy(refobj)
            forward.multi = True
            obj.refs.append(forward)

    def _collect_reverse_refs(self, obj):
        """ Generate foreign key reverse references for obj. """
        for f in obj.model._meta.related_objects:
            related_model = getattr(f, "related_model", None)
            if not f.related_name or "+" in f.related_name:
                continue
            if not related_model:
                continue

            key = related_model.__name__.lower()
            if key in self:
                cobj = copy.deepcopy(self[key])
                cobj.related_name = f.related_name
                obj.reverse_refs.append(cobj)

def _load_dictionaries(paths):
    """ Parse each JSON replacement dictionary named in 'paths', in order. """
    dictionaries = []
    for path in paths:
        with open(path) as fh:
            dictionaries.append(json.load(fh))
    return dictionaries


def _apply_dictionaries(text, dictionaries):
    """ Replace whole-word dictionary keys in 'text' by their values.

        The dictionaries are applied one after another, each working on the
        result of the previous one. Keys are matched literally.
    """
    for mapping in dictionaries:
        if not mapping:
            # An empty alternation would match the empty string everywhere.
            continue
        alternatives = '|'.join(re.escape(key) for key in mapping)
        pattern = re.compile(r'\b(' + alternatives + r')\b')
        text = pattern.sub(lambda m, mapping=mapping: mapping[m.group()], text)
    return text


def main():
    """ Command-line entry point: render a template over the XOS models.

        Usage: modelgen [options] <template_fn>

        The template is rendered with a single variable, "generator". In the
        rendered text, a line starting with '+++' closes the lines gathered
        so far: they are written, after applying the -d dictionaries, to the
        file named from the fifth character of that line on. Lines left over
        after the last '+++' line are printed to stdout.

        Values given with -b / -n are appended to the default lists.
        Exits with status 1 on a wrong argument count and -1 when no "core"
        directory is found in ".", ".." or "../..".
    """
    global options
    parser = OptionParser(usage="modelgen [options] <template_fn>", )

    parser.add_option("-d", "--dict", dest="dict",
         help="dictionary to replace text in output", metavar="DICT", default=[], action="append")

    parser.add_option("-a", "--app", dest="apps",
         help="list of applications to parse", metavar="APP", default=[], action="append")
    parser.add_option("-b", "--blacklist", dest="blacklist",
         help="add model name to blacklist", metavar="MODEL", default=["SingletonModel", "PlCoreBase"], action="append")
    parser.add_option("-n", "--no-rest", dest="norest",
         help="do not generate rest api for model", metavar="MODEL", default=["SingletonModel", "PlCoreBase"], action="append")

    (options, args) = parser.parse_args(sys.argv[1:])

    # Validate before touching args[0]; a missing template name used to
    # surface as an IndexError instead of the usage message.
    if len(args) != 1:
        print('Usage: modelgen [options] <template_fn>')
        sys.exit(1)

    # Resolve the template path before any chdir() below changes what a
    # relative path refers to.
    template_name = os.path.abspath(args[0])

    # try to make sure we're running from the right place
    if not os.path.exists("core"):
        if os.path.exists("../core"):
            os.chdir("..")
        elif os.path.exists("../../core"):
            os.chdir("../..")
        else:
            sys.stderr.write("Are you sure you're running modelgen from the root of an XOS installation\n")
            sys.exit(-1)

    if not options.apps:
        options.apps = ["core"]

    if options.apps == ["*"]:
        options.apps = [x for x in settings.INSTALLED_APPS if app_get_models_module(x)]

    generator = Generator()
    for m in enum_classes(options.apps):
        generator.add_object(m)
    generator.compute_links()

    template_dir, template_file = os.path.split(template_name)
    os_template_loader = jinja2.FileSystemLoader(searchpath=[template_dir])
    os_template_env = jinja2.Environment(loader=os_template_loader)

    template = os_template_env.get_template(template_file)
    rendered = template.render({"generator": generator})

    # Loaded on first use, so the dictionary files are only opened when the
    # template actually emits an output file, and only once.
    dictionaries = None
    current_buffer = []
    for line in rendered.splitlines():
        if line.startswith('+++'):
            filename = line[4:]
            if dictionaries is None:
                dictionaries = _load_dictionaries(options.dict)

            output = _apply_dictionaries('\n'.join(current_buffer), dictionaries)
            with open(filename, 'w') as fil:
                fil.write(output)

            print('Written to file %s' % filename)
            current_buffer = []
        else:
            current_buffer.append(line)

    if current_buffer:
        print('\n'.join(current_buffer))


# Run the generator only when this file is executed as a script.
if __name__ == '__main__':
    main()
