#!/usr/bin/env python

from swiginac import *
import sfc
from sfc import *
from sfc.symbolic_utils import *
import sfc.common.options
from sfc.common.utilities import get_callable_name

# ----------------------------------------------- Scalar forms:

def constant_scalar(itg):
    r"""a(;) = \int 1 dx

    Scalar form with no basis functions and no coefficients: the integrand
    is the constant 1.

    itg -- integrand context object (not used by this form).
    """
    return numeric(1)

def L2_scalar(w, itg):
    r"""a(;w) = \int w^2 dx

    Squared L2 norm of the coefficient w: the integrand is inner(w, w).

    w   -- the form's single coefficient.
    itg -- integrand context object (not used by this form).
    """
    return inner(w, w)

def H1_semi_scalar(w, itg):
    r"""a(;w) = \int grad(w)^2 dx

    Squared H1 seminorm of the coefficient w: the integrand is
    inner(grad(w), grad(w)).

    w   -- the form's single coefficient.
    itg -- integrand context; its GinvT() is passed to grad(), presumably
           the transposed inverse Jacobian of the cell mapping (confirm).
    """
    GinvT = itg.GinvT()
    Dw = grad(w, GinvT)
    return inner(Dw, Dw)

def H1_scalar(w, itg):
    r"""a(;w) = \int w^2 + grad(w)^2 dx

    Squared full H1 norm of the coefficient w: the L2 part inner(w, w)
    plus the seminorm part inner(grad(w), grad(w)).

    w   -- the form's single coefficient.
    itg -- integrand context; its GinvT() is passed to grad(), presumably
           the transposed inverse Jacobian of the cell mapping (confirm).
    """
    GinvT = itg.GinvT()
    Dw = grad(w, GinvT)
    return inner(w, w) + inner(Dw, Dw)

# Integrand callbacks for scalar forms (no basis functions).
scalar_forms = [
    constant_scalar,
    L2_scalar,
    H1_semi_scalar,
    H1_scalar,
]



# ----------------------------------------------- Vector forms:

def constant_vector(v, itg):
    r"""a(v;) = \int 1 dx

    Vector form with one basis function v whose integrand ignores v and
    is the constant 1.

    v   -- test basis function (not used by the integrand).
    itg -- integrand context object (not used by this form).
    """
    return numeric(1)

def constant_source_vector(v, itg):
    r"""a(v;) = \int v dx

    Vector form with a unit source: the integrand is the test basis
    function v itself.

    v   -- test basis function.
    itg -- integrand context object (not used by this form).
    """
    return v

def source_vector(v, f, itg):
    r"""a(v; f) = \int f . v dx

    Vector form for a source coefficient f: the integrand is inner(f, v).

    v   -- test basis function.
    f   -- source coefficient.
    itg -- integrand context object (not used by this form).
    """
    return inner(f, v)


# Integrand callbacks for vector forms (one basis function).
vector_forms = [
    constant_vector,
    constant_source_vector,
    source_vector,
]



# ----------------------------------------------- Vector boundary forms:

def load_vector(v, t, itg):
    r"""a(v; t) = \int t . v ds

    Boundary vector form for a load coefficient t: the integrand is
    inner(t, v). It is meant to be registered as an exterior-facet
    integrand, so the measure is ds (not dx).

    v   -- test basis function.
    t   -- boundary load coefficient.
    itg -- integrand context object (not used by this form).
    """
    return inner(t, v)


# ----------------------------------------------- Matrix forms:

def constant_matrix(v, u, itg):
    r"""a(v, u; ) = \int 1 dx

    Matrix form with test function v and trial function u whose integrand
    ignores both and is the constant 1.

    v   -- test basis function (not used by the integrand).
    u   -- trial basis function (not used by the integrand).
    itg -- integrand context object (not used by this form).
    """
    return numeric(1)

def mass_matrix(v, u, itg):
    r"""a(v, u; ) = \int u . v dx

    Mass matrix: the integrand is inner(v, u).

    v   -- test basis function.
    u   -- trial basis function.
    itg -- integrand context object (not used by this form).
    """
    return inner(v, u)

def mass_with_c_matrix(v, u, c, itg):
    r"""a(v, u; c) = \int c (u . v) dx

    Mass matrix weighted by the coefficient c: the integrand is
    c * inner(v, u).

    v   -- test basis function.
    u   -- trial basis function.
    c   -- weighting coefficient (used with scalar elements by the driver).
    itg -- integrand context object (not used by this form).
    """
    return c * inner(v, u)

def stiffness_matrix(v, u, itg):
    r"""a(v, u; ) = \int grad(u) : grad(v) dx

    Stiffness (Laplace) matrix: the integrand is inner(grad(u), grad(v)).

    v   -- test basis function.
    u   -- trial basis function.
    itg -- integrand context; its GinvT() is passed to grad(), presumably
           the transposed inverse Jacobian of the cell mapping (confirm).
    """
    GinvT = itg.GinvT()
    Du = grad(u, GinvT)
    Dv = grad(v, GinvT)
    return inner(Du, Dv)

def stiffness_with_M_matrix(v, u, M, itg):
    r"""a(v, u; M) = \int (M grad(u)) : grad(v) dx

    Stiffness matrix with a coefficient M applied to the trial gradient:
    the integrand is inner(M * grad(u), grad(v)).

    v   -- test basis function.
    u   -- trial basis function.
    M   -- coefficient multiplying grad(u) (scalar or tensor elements are
           used for it by the driver).
    itg -- integrand context; its GinvT() is passed to grad(), presumably
           the transposed inverse Jacobian of the cell mapping (confirm).
    """
    GinvT = itg.GinvT()
    Du = grad(u, GinvT)
    Dv = grad(v, GinvT)
    return inner(M * Du, Dv)

# Integrand callbacks for matrix forms (test and trial basis functions).
matrix_forms = [
    constant_matrix,
    mass_matrix,
    mass_with_c_matrix,
    stiffness_matrix,
    stiffness_with_M_matrix,
]



# ----------------------------------------------- Boundary matrix forms:

def mass_boundary_matrix(v, u, itg):
    r"""a(v, u; ) = \int u . v ds

    Boundary mass matrix: the integrand is inner(v, u), meant to be
    registered as an exterior-facet integrand.

    v   -- test basis function.
    u   -- trial basis function.
    itg -- integrand context object (not used by this form).
    """
    return inner(v, u)



# ----------------------------------------------- Testing:


if __name__ == "__main__":
    import sys

    sfc.common.options.add_debug_code = False

    sfc.common.options.print_options()

    args = set(sys.argv[1:])

    print_forms = True if "p" in args else False
    generate    = True if "g" in args else False
    compile     = True if "c" in args else False

    def check(form):
        form.sanity_check()
        if print_forms:
            print form
        if generate or compile:
            if compile:
                compiled_form = compile_form(form)
                print "Successfully compiled form:"
                print compiled_form
                print dir(compiled_form)
            else:
                res = write_ufc_code(form)
                print "Successfully generated form code:"
                print res

    quad_order = 3

    formcount = 0
    def form_name(callback):
        global formcount
        name = "form_%s_%d" % (get_callable_name(callback), formcount)
        formcount += 1
        return name

    for nsd in [2, 3]:
        polygon = { 2: "triangle", 3: "tetrahedron" }[nsd]
        print "Using polygon = ", polygon

        fe0 = FiniteElement("P0", polygon, 0)
        fe1 = FiniteElement("Lagrange", polygon, 1)
        fe2 = FiniteElement("Lagrange", polygon, 2)
        scalar_elements = [fe0, fe1, fe2]

        vfe0 = VectorElement("P0", polygon, 0)
        vfe1 = VectorElement("Lagrange", polygon, 1)
        vfe2 = VectorElement("Lagrange", polygon, 2)
        vector_elements = [vfe0, vfe1, vfe2]

        tfe0 = TensorElement("P0", polygon, 0)
        tfe1 = TensorElement("Lagrange", polygon, 1)
        tfe2 = TensorElement("Lagrange", polygon, 2)
        tensor_elements = [tfe0, tfe1, tfe2]
        
        # quicker, for debugging:
        scalar_elements = [fe1]
        vector_elements = [vfe1]
        tensor_elements = [tfe1]

        all_elements = scalar_elements + vector_elements + tensor_elements


        for symbolic in [False, True]:
            print "Using symbolic = ", symbolic

            options = { "symbolic": symbolic, "quad_order": quad_order }

            # creating scalar callback forms:

            print "Testing scalar forms"

            for fe in all_elements:

                callback = L2_scalar
                basisfunctions = []
                coefficients   = [Function(fe)]
                form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, cell_integrands=[callback], options=options)
                check(form)
            
            for fe in scalar_elements + vector_elements:
                
                callback = H1_semi_scalar
                basisfunctions = []
                coefficients   = [Function(fe)]
                form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, cell_integrands=[callback], options=options)
                check(form)
                
                callback = H1_scalar
                basisfunctions = []
                coefficients   = [Function(fe)]
                form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, cell_integrands=[callback], options=options)
                check(form)


            # creating vector callback forms:

            print "Testing vector forms"

            for fe in all_elements:
                callback = constant_vector
                basisfunctions = [TestFunction(fe)]
                coefficients   = []
                form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, cell_integrands=[callback], options=options)
                check(form)
                
            for fe in scalar_elements:
                callback = constant_source_vector
                basisfunctions = [TestFunction(fe)]
                coefficients   = []
                form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, cell_integrands=[callback], options=options)
                check(form)
            
            for fe in scalar_elements:
                for f_fe in scalar_elements:
                    callback = source_vector
                    basisfunctions = [TestFunction(fe)]
                    coefficients   = [Function(f_fe)]
                    form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, cell_integrands=[callback], options=options)
                    check(form)

                    callback = load_vector
                    basisfunctions = [TestFunction(fe)]
                    coefficients   = [Function(f_fe)]
                    form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, exterior_facet_integrands=[callback], options=options)
                    check(form)

            for fe in vector_elements:
                for f_fe in vector_elements:
                    callback = source_vector
                    basisfunctions = [TestFunction(fe)]
                    coefficients   = [Function(f_fe)]
                    form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, cell_integrands=[callback], options=options)
                    check(form)

                    callback = load_vector
                    basisfunctions = [TestFunction(fe)]
                    coefficients   = [Function(f_fe)]
                    form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, exterior_facet_integrands=[callback], options=options)
                    check(form)

            for fe in tensor_elements:
                for f_fe in tensor_elements:
                    callback = source_vector
                    basisfunctions = [TestFunction(fe)]
                    coefficients   = [Function(f_fe)]
                    form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, cell_integrands=[callback], options=options)
                    check(form)

                    callback = load_vector
                    basisfunctions = [TestFunction(fe)]
                    coefficients   = [Function(f_fe)]
                    form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, exterior_facet_integrands=[callback], options=options)
                    check(form)


            # creating matrix callback forms:

            print "Testing matrix forms"

            for fe in all_elements:

                callback = constant_matrix
                basisfunctions = [TestFunction(fe), TrialFunction(fe)]
                coefficients   = []
                form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, cell_integrands=[callback], options=options)
                check(form)
                
                callback = mass_matrix
                basisfunctions = [TestFunction(fe), TrialFunction(fe)]
                coefficients   = []
                form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, cell_integrands=[callback], options=options)
                check(form)
                
                callback = mass_boundary_matrix
                basisfunctions = [TestFunction(fe), TrialFunction(fe)]
                coefficients   = []
                form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, exterior_facet_integrands=[callback], options=options)
                check(form)
                
                for c_fe in scalar_elements:
                    callback = mass_with_c_matrix
                    basisfunctions = [TestFunction(fe), TrialFunction(fe)]
                    coefficients   = [Function(c_fe)]
                    form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, cell_integrands=[callback], options=options)
                    check(form)
                
            for fe in scalar_elements + vector_elements:
                callback = stiffness_matrix
                basisfunctions = [TestFunction(fe), TrialFunction(fe)]
                coefficients   = []
                form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, cell_integrands=[callback], options=options)
                check(form)
                
                for M_fe in scalar_elements + tensor_elements:
                    callback = stiffness_with_M_matrix
                    basisfunctions = [TestFunction(fe), TrialFunction(fe)]
                    coefficients   = [Function(M_fe)]
                    form = Form(name=form_name(callback), basisfunctions=basisfunctions, coefficients=coefficients, cell_integrands=[callback], options=options)
                    check(form)

