#!/usr/bin/env python

# Copyright (C) 2012 Anders Logg
#
# This file is part of DOLFIN.
#
# DOLFIN is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# DOLFIN is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with DOLFIN. If not, see <http://www.gnu.org/licenses/>.
#
# First added:  2012-03-09
# Last changed: 2012-04-25

from dolfin import *
from time import time

SIZE = 32
NUM_REPS = 3

print "Solving the linear system for Poisson's equation using PETSc CUSP"

# Check that we have the PETSc Cusp backend
if not has_linear_algebra_backend("PETScCusp"):
    print "Cannot run this benchmark since PETSc Cusp is not available."
    print "BENCH: 0.0"
    exit(0)

# Function for running test
def run_bench(linear_algebra_backend):

    info("")
    info("Linear algebra backend: %s" % linear_algebra_backend)

    # Set linear algebra backend
    parameters["linear_algebra_backend"] = linear_algebra_backend

    # Create matrix and vector
    print "Assembling matrix and vector"
    mesh = UnitCube(SIZE, SIZE, SIZE)
    V = FunctionSpace(mesh, "Lagrange", 1)
    u = TrialFunction(V)
    v = TestFunction(V)
    f = Constant(1.0)
    a = dot(grad(u), grad(v))*dx
    L = f*v*dx
    bc = DirichletBC(V, 0.0, DomainBoundary())
    A, b = assemble_system(a, L, bc)

    # Create linear solver
    solver = KrylovSolver("cg", "jacobi")

    # Use hack to get around PETSc Cusp bug
    solver.parameters["use_petsc_cusp_hack"] = True

    # Solve linear system
    info("Solving linear system %d times" % NUM_REPS)
    x = Vector()
    cpu_time = time()
    for i in range(NUM_REPS):
        x.zero()
        solver.solve(A, x, b)
        print "residual =", residual(A, x, b)
    cpu_time = (time() - cpu_time) / float(NUM_REPS)

    return cpu_time

# Run benchmarks
cpu_time_petsc_cusp = run_bench("PETScCusp")
cpu_time_petsc      = run_bench("PETSc")

# Compute speedup
speedup = cpu_time_petsc / cpu_time_petsc_cusp

# Report results
print
print "PETSc:     ", cpu_time_petsc
print "PETSc Cusp:", cpu_time_petsc_cusp
print "Speedup:   ", speedup
print

print "BENCH: ", speedup
