/*!
 *  \file KernelTargetAlignment.h
 *
 *  \brief Negative kernel target alignment (KTA) objective function for kernel parameter selection
 *
 *  \author T.Voss, T. Glasmachers, O.Krause
 *  \date 2010-2011
 *
 *  \par Copyright (c) 1998-2007:
 *      Institut f&uuml;r Neuroinformatik<BR>
 *      Ruhr-Universit&auml;t Bochum<BR>
 *      D-44780 Bochum, Germany<BR>
 *      Phone: +49-234-32-25558<BR>
 *      Fax:   +49-234-32-14209<BR>
 *      eMail: Shark-admin@neuroinformatik.ruhr-uni-bochum.de<BR>
 *      www:   http://www.neuroinformatik.ruhr-uni-bochum.de<BR>
 *      <BR>
 *
 *
 *  <BR><HR>
 *  This file is part of Shark. This library is free software;
 *  you can redistribute it and/or modify it under the terms of the
 *  GNU General Public License as published by the Free Software
 *  Foundation; either version 3, or (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this library; if not, see <http://www.gnu.org/licenses/>.
 *
 */
#ifndef SHARK_OBJECTIVEFUNCTIONS_KERNELTARGETALIGNMENT_H
#define SHARK_OBJECTIVEFUNCTIONS_KERNELTARGETALIGNMENT_H

#include <shark/ObjectiveFunctions/DataObjectiveFunction.h>
#include <shark/Models/Kernels/AbstractKernelFunction.h>


namespace shark{

/*!
 *  \brief Implementation of the negative Kernel Target
 *  Alignment (KTA) as proposed by Nello Cristianini.
 *  This measure is extended for multi-class problems.
 *
 *  \par
 *  Kernel Target Alignment measures how well a kernel
 *  fits a classification training set. It is invariant
 *  under kernel rescaling. To turn it into an error
 *  function (i.e., to make minimization meaningful)
 *  the negative KTA is implemented.
 *
 *  \par
 *  The KTA has two important properties: It is
 *  differentiable and independent of the actual
 *  classifier.
 *
 *  \par
 *  The KTA as originally proposed by Nello Cristianini
 *  is not properly arranged for unbalanced datasets.
 *
 *  \par
 *  KTA measures the similarity, in terms of the inner
 *  product, of the kernel Gram matrix K with a perfect
 *  Gram matrix D with entries 1 or -1 for examples of
 *  coinciding or different label, respectively. Then
 *  the kernel target alignment is given by
 *  \f[
 *  	\hat A = \frac{\langle D, K \rangle}{\sqrt{\langle D, D \rangle \langle K, K\rangle}}
 *  \f]
 *  We generalize the measure by using the value -1/(N-1)
 *  for entries corresponding to different classes, which
 *  gives a canonical and symmetric generalization. Here,
 *  N denotes the number of classes.
 */
template<class InputType = RealVector>
class KernelTargetAlignment : public SupervisedObjectiveFunction<InputType, unsigned int>
{
public:
	typedef SupervisedObjectiveFunction<InputType,unsigned int> base_type;
	typedef typename base_type::SearchPointType SearchPointType;
	typedef typename base_type::ResultType ResultType;
	typedef typename base_type::FirstOrderDerivative FirstOrderDerivative;
	typedef typename base_type::SecondOrderDerivative SecondOrderDerivative;

	/// \brief Creates the objective function for a given kernel.
	///
	/// \param kernel          kernel whose parameters are the search point; must not be NULL.
	///                        Only the pointer is stored and this class never deletes it,
	///                        so the caller has to keep the kernel alive.
	/// \param numberOfClasses number of classes N; entries of the target matrix belonging
	///                        to different classes get the weight -1/(N-1).
	///                        NOTE: for N == 1 this weight is not finite; it is only used
	///                        when two examples carry different labels.
	///
	/// The first-derivative feature flag is only set if the kernel reports
	/// hasFirstParameterDerivative().
	KernelTargetAlignment(AbstractKernelFunction<InputType>* kernel, unsigned int numberOfClasses){
		SHARK_CHECK(kernel != NULL, "[KernelTargetAlignment] kernel is not allowed to be NULL");
		
		m_kernel = kernel;
		m_numberOfClasses = numberOfClasses;
		// Pre-compute the target-matrix weight for differently labeled pairs and
		// the derived constants used in the inner loops.
		m_offdiag = -1.0 / (numberOfClasses - 1.0);
		m_2offdiag = 2.0 * m_offdiag;
		m_offdiag2 = m_offdiag * m_offdiag;
		m_2offdiag2 = 2.0 * m_offdiag2;
		
		this->m_name = "KernelTargetAlignment";
		this->m_features|=base_type::HAS_VALUE;
		this->m_features|=base_type::CAN_PROPOSE_STARTING_POINT;
		
		if(m_kernel->hasFirstParameterDerivative())
			this->m_features|=base_type::HAS_FIRST_DERIVATIVE;
	}

	/// \brief Forwards the child node named "kernel", if present, to the kernel's configure().
	void configure( const PropertyTree & node ){
		PropertyTree::const_assoc_iterator it = node.find("kernel");
		if(it != node.not_found()){
			m_kernel->configure(it->second);
		}
	}

	/// \brief Sets the labeled dataset on which the alignment is evaluated.
	void setDataset(LabeledData<InputType, unsigned int> const& dataset){
		m_data = dataset;
	}

	/// \brief Proposes the kernel's current parameter vector as starting point.
	void proposeStartingPoint(SearchPointType& startingPoint) const {
		startingPoint = m_kernel->parameterVector();
	}

	/// \brief Evaluates the negative kernel target alignment.
	///
	/// Sets the kernel parameters to \a input and returns
	/// -<D,K> / sqrt(<D,D> <K,K>), where K is the kernel Gram matrix of the
	/// stored dataset and D the target matrix (1 on the diagonal and for equal
	/// labels, -1/(N-1) otherwise).
	///
	/// \param input kernel parameter vector to evaluate.
	/// \return the negative alignment.
	double eval(RealVector const& input) const{
		m_kernel->setParameterVector(input);
		
		double DD = 0.0;	// <D,D>
		double DK = 0.0;	// <D,K>
		double KK = 0.0;	// <K,K>
		for (std::size_t i=0; i < m_data.size(); i++)
		{
			// Strict lower triangle; each pair (i,j) also stands for (j,i),
			// hence the factors of two.
			for (std::size_t j=0; j < i; j++)
			{
				double k = m_kernel->eval(m_data.input(i), m_data.input(j));
				if (m_data.label(i) == m_data.label(j))
				{
					DD += 2.0;
					DK += 2.0 * k;
				}
				else
				{
					DD += m_2offdiag2;
					DK += m_2offdiag * k;
				}
				KK += 2.0 * k * k;
			}
			// Diagonal entry, D_ii = 1.
			double k = m_kernel->eval(m_data.input(i), m_data.input(i));
			DD += 1.0;
			DK += k;
			KK += k * k;
		}
		return ( - DK / std::sqrt(DD * KK));
	}

	/// \brief Evaluates the negative kernel target alignment and its gradient
	///        with respect to the kernel parameters.
	///
	/// Sets the kernel parameters to \a input, so value and gradient refer to
	/// the requested search point (consistent with eval()).
	///
	/// With f = -<D,K> / sqrt(<D,D> <K,K>) the gradient is
	///   df = <D,K> / (<K,K> * denom) * (1/2 d<K,K>)  -  1/denom * d<D,K>,
	/// where denom = sqrt(<D,D> <K,K>).
	///
	/// \param input      kernel parameter vector to evaluate.
	/// \param derivative receives the gradient in derivative.m_gradient.
	/// \return the negative alignment at \a input.
	ResultType evalDerivative( const SearchPointType & input, FirstOrderDerivative & derivative ) const {
		m_kernel->setParameterVector(input);

		std::size_t parameters = m_kernel->numberOfParameters();
		derivative.m_gradient.resize(parameters);
		RealVector der1(parameters);	// accumulates d<D,K>/dtheta
		RealVector der2(parameters);	// accumulates 1/2 * d<K,K>/dtheta
		RealVector kernelDerivative(parameters);
		der1.clear();
		der2.clear();

		double DD = 0.0;	// <D,D>
		double DK = 0.0;	// <D,K>
		double KK = 0.0;	// <K,K>
		for (std::size_t i=0; i < m_data.size(); i++)
		{
			// Strict lower triangle; each pair (i,j) also stands for (j,i),
			// hence the factors of two.
			for (std::size_t j=0; j < i; j++)
			{
				double k = (*m_kernel)(m_data.input(i), m_data.input(j),kernelDerivative);
				if (m_data.label(i) == m_data.label(j))
				{
					DD += 2.0;
					DK += 2.0 * k;
					der1 += 2.0 * kernelDerivative;
				}
				else
				{
					DD += m_2offdiag2;
					DK += m_2offdiag * k;
					der1 += m_2offdiag * kernelDerivative;
				}
				KK += 2.0 * k * k;
				der2 += 2.0 * k * kernelDerivative;
			}
			// Diagonal entry, D_ii = 1. It occurs only once, so its share of
			// 1/2 * d<K,K> is k * dk (not just dk).
			double k = (*m_kernel)(m_data.input(i), m_data.input(i),kernelDerivative);
			DD += 1.0;
			DK += k;
			KK += k * k;
			der1 += kernelDerivative;
			der2 += k * kernelDerivative;
		}

		double denom = std::sqrt(DD * KK);
		double f1 = 1.0 / denom;
		double f2 = DK / (KK * denom);
		noalias(derivative.m_gradient) = f2 * der2 - f1 * der1;

		return (-DK / denom);
	}
private:
	AbstractKernelFunction<InputType>* m_kernel;	///< kernel under optimization (not deleted by this class)
	LabeledData<InputType,unsigned int> m_data;	///< dataset the alignment is computed on

	unsigned int m_numberOfClasses;	///< number of classes N
	double m_offdiag;	///< target entry for different labels: -1/(N-1)
	double m_offdiag2;	///< m_offdiag squared
	double m_2offdiag;	///< 2 * m_offdiag
	double m_2offdiag2;	///< 2 * m_offdiag squared
};

}
#endif
