//===================================================================
// matrix.hpp
//
// Version 1.0
//
// Written by:
//   Brent Worden
//   WordenWare
//   email:  Brent@Worden.org
//
// Copyright (c) 1999 WordenWare
//
// Created:  April 10, 1999
// Revised:  
//===================================================================

#ifndef _MATRIX_HPP_
#define _MATRIX_HPP_

#include <algorithm>
#include <cstdlib>
#include <functional>
#include <numeric>
#include <cmath>

#include "funcobj.hpp"
#include "vector.hpp"

NUM_BEGIN

template<class T>
class Matrix
//-------------------------------------------------------------------
// mathematical matrix object
//-------------------------------------------------------------------
{
public:
	typedef Vector<T>* iterator;
	typedef const Vector<T>* const_iterator;
	typedef size_t size_type;

	Matrix()
	//---------------------------------------------------------------
	// Create an empty matrix.
	//---------------------------------------------------------------
	: _row_capacity(0), _col_capacity(0), _rows(0), _cols(0), _data(NULL) {}

	Matrix(size_type m, size_type n)
	//---------------------------------------------------------------
	// Create a m x n matrix
	//---------------------------------------------------------------
	: _row_capacity(0), _col_capacity(0), _rows(m), _cols(n)
	{
		_data = _allocate(m, n, T(0));
	}

	Matrix(const Matrix<T>& other)
	//---------------------------------------------------------------
	// Create a matrix by copying other.
	//---------------------------------------------------------------
	{
		size_type r = other.rows();
		size_type c = other.columns();
		_row_capacity = r;
		_col_capacity = c,
		_rows = r;
		_cols = c,
		_data = _allocate(rowCapacity(), columnCapacity(), T(0));
		std::copy(other.row_begin(), other.row_end(), _data);
	}

	Matrix(const_iterator first, const_iterator last)
	//---------------------------------------------------------------
	// Create a matrix with initial rows [first, last).
	//---------------------------------------------------------------
	: _rows(_length(first, last)), _cols(first->size())
	{
		size_type r = _length(first, last);
		size_type c = first->size();
		_row_capacity = r;
		_col_capacity = c,
		_rows = r;
		_cols = c,
		_data = _allocate(rowCapacity(), columnCapacity(), T(0));
		std::copy(first, last, _data);
	}

	~Matrix()
	//---------------------------------------------------------------
	// Destroy this matrix.
	//---------------------------------------------------------------
	{ _destroy(); }

	const_iterator row_begin() const
	//---------------------------------------------------------------
	// An iterator positioned at the first row of this matrix.
	//---------------------------------------------------------------
	{ return (const_iterator)_data; }

	const_iterator row_end() const
	//---------------------------------------------------------------
	// An interator positioned one position past the last row of this
	// matrix.
	//---------------------------------------------------------------
	{ return (const_iterator)(_data + _rows); }

	iterator row_begin()
	//---------------------------------------------------------------
	// An iterator positioned at the first row of this matrix.
	//---------------------------------------------------------------
	{ return _data; }

	iterator row_end()
	//---------------------------------------------------------------
	// An interator positioned one position past the last row of this
	// matrix.
	//---------------------------------------------------------------
	{ return (_data + _rows); }

	size_type rows() const
	//---------------------------------------------------------------
	// Access the number of rows of this matrix.
	//---------------------------------------------------------------
	{ return _rows; }

	size_type columns() const
	//---------------------------------------------------------------
	// Access the number of columns of this matrix.
	//---------------------------------------------------------------
	{ return _cols; }

	size_type rowCapacity() const
	//---------------------------------------------------------------
	// Access the capcity of rows of this matrix.
	//---------------------------------------------------------------
	{ return _row_capacity; }

	size_type columnCapacity() const
	//---------------------------------------------------------------
	// Access the number of columns of this matrix.
	//---------------------------------------------------------------
	{ return _col_capacity; }

	const Matrix<T>& operator=(const Matrix<T>& rhs)
	//---------------------------------------------------------------
	// Assign rhs to this matrix.
	//---------------------------------------------------------------
	{ 
		if(this != &rhs){
			size_type r = rhs.rows();
			size_type c = rhs.columns();
			if(r <= rowCapacity()){
				std::copy(rhs.row_begin(), rhs.row_end(), _data);
			} else {
				_destroy();
				_data = _allocate(r, c, T(0));
				std::copy(rhs.row_begin(), rhs.row_end(), _data);
				_row_capacity = r;
				_col_capacity = c;
			}
			_rows = r;
			_cols = c;
		}
		return *this;
	}

	const Matrix<T>& operator+=(const Matrix<T>& rhs)
	//---------------------------------------------------------------
	// Assign the matrix sum of this matrix and rhs to this matrix.
	//---------------------------------------------------------------
	{
		if(rows() != rhs.rows() || columns() != rhs.columns()){
			throw Exception("Matrix<T>::operator+=", "Incompatible matrix sizes");
		}
		std::transform(row_begin(), row_end(), rhs.row_begin(), row_begin(), std::plus<Vector<T> >());
		return *this;
	}

	const Matrix<T>& operator-=(const Matrix<T>& rhs)
	//---------------------------------------------------------------
	// Assign the matrix difference of this matrix and rhs to this
	// matrix.
	//---------------------------------------------------------------
	{
		if(rows() != rhs.rows() || columns() != rhs.columns()){
			throw Exception("Matrix<T>::operator-=", "Incompatible matrix sizes");
		}
		std::transform(row_begin(), row_end(), rhs.row_begin(), row_begin(), std::minus<Vector<T> >());
		return *this;
	}

	const Matrix<T>& operator*=(T scale)
	//---------------------------------------------------------------
	// Assign the scalar procuct of this matrix and scale to this
	// matrix.
	//---------------------------------------------------------------
	{
		std::transform(row_begin(), row_end(), row_begin(), ScaleValue<Vector<T> >(scale));
		return *this;
	}

	Vector<T>& operator[](size_type index)
	//---------------------------------------------------------------
	// Access the index-th row of this matrix.
	//---------------------------------------------------------------
	{
		if(index < 0 || index >= rows()){
			throw Exception("Matrix<T>::operator[]", "invalid index parameter");
		}
		return *(_data + index);
	}

	void resize(size_type r, size_type c)
	//---------------------------------------------------------------
	// Resize this matrix to r x c.
	//---------------------------------------------------------------
	{
		if(r > rowCapacity() || c > columnCapacity()){
			iterator tmp = _allocate(r, c, T(0));
			std::copy(_data, _data + rows(), tmp);
			_destroy();
			_data = tmp;
			_row_capacity = r;
			_col_capacity = c;
		}
		_rows = r;
		_cols = c;
	}

private:
	iterator _allocate(size_type r, size_type c, const T& init = T(0))
	//---------------------------------------------------------------
	// Create r x c elements all initialized to init.
	//---------------------------------------------------------------
	{
		iterator ret = new Vector<T>[r];
		for(iterator p = ret; p != ret + r; ++p){
			p->resize(c);
			std::fill(p->begin(), p->end(), init);
		}
		return ret;
	}

	void _destroy()
	//---------------------------------------------------------------
	// Destroy all the elements in this Vector.
	//---------------------------------------------------------------
	{
		delete [] _data;
		_row_capacity = 0;
		_rows = 0;
		_col_capacity = 0;
		_cols = 0;
		_data = NULL;
	}

	size_type _length(const_iterator first, const_iterator last)
	//---------------------------------------------------------------
	// Return the number of elements in [first, last).
	//---------------------------------------------------------------
	{
		size_type ret;
		for(ret = 0; first != last; ++ret, ++first);
		return ret;
	}

	size_type _row_capacity;
	size_type _col_capacity;
	size_type _rows;         // number of rows
	size_type _cols;         // number of columns
	Vector<T>* _data;  // matrix rows
};

template<class T>
Matrix<T> operator+(const Matrix<T>& lhs, const Matrix<T>& rhs)
//-------------------------------------------------------------------
// Matrix sum of lhs and rhs. Throws Exception if the operands do not
// have the same dimensions.
//-------------------------------------------------------------------
{
	if((lhs.columns() != rhs.columns()) ||
	   (lhs.rows() != rhs.rows())){
		throw Exception("operator+(Matrix<T>, Matrix<T>)", "Incompatible sizes");
	}
	Matrix<T> ret(lhs);
	// Return the named local, not the const& from +=, so the result
	// can be constructed in place instead of copied.
	ret += rhs;
	return ret;
}

template<class T>
Matrix<T> operator-(const Matrix<T>& lhs, const Matrix<T>& rhs)
//-------------------------------------------------------------------
// Matrix difference lhs - rhs. Throws Exception if the operands do
// not have the same dimensions.
//-------------------------------------------------------------------
{
	if((lhs.columns() != rhs.columns()) ||
	   (lhs.rows() != rhs.rows())){
		throw Exception("operator-(Matrix<T>, Matrix<T>)", "Incompatible sizes");
	}
	Matrix<T> ret(lhs);
	// Return the named local, not the const& from -=, so the result
	// can be constructed in place instead of copied.
	ret -= rhs;
	return ret;
}

template<class T>
Matrix<T> operator*(const Matrix<T>& lhs, const T& rhs)
//-------------------------------------------------------------------
// Scalar product of lhs and rhs.
//-------------------------------------------------------------------
{
	Matrix<T> ret(lhs);
	// *= returns a const reference; returning ret by name instead
	// avoids copying the whole matrix a second time.
	ret *= rhs;
	return ret;
}

template<class T>
Matrix<T> operator*(const T& lhs, const Matrix<T>& rhs)
//-------------------------------------------------------------------
// Scalar product with the scalar on the left: a copy of rhs with
// every element multiplied by lhs.
//-------------------------------------------------------------------
{
	Matrix<T> scaled(rhs);
	scaled *= lhs;
	return scaled;
}

template<class T>
Matrix<T> operator*(const Matrix<T>& lhs, const Matrix<T>& rhs)
//-------------------------------------------------------------------
// Matrix product of lhs (m x n) and rhs (n x p); the result is m x p.
// Throws Exception if lhs.columns() != rhs.rows().
//-------------------------------------------------------------------
{
	if(lhs.columns() != rhs.rows()){
		throw Exception("operator*(Matrix<T>, Matrix<T>)", "Incompatible matrix sizes");
	}
	// Element access here goes through Vector<T>::iterator, which needs
	// non-const rows. Work on mutable copies of the operands; copying is
	// O(m*n + n*p), small next to the O(m*n*p) product itself.
	Matrix<T> a(lhs);
	Matrix<T> b(rhs);
	Matrix<T> ret(lhs.rows(), rhs.columns());   // every element T(0)

	// Row-oriented product: row i of ret accumulates a[i][k] * (row k
	// of b) over every k, so all three matrices are walked row-wise.
	typename Matrix<T>::iterator out = ret.row_begin();
	for(typename Matrix<T>::iterator arow = a.row_begin();
	    arow != a.row_end(); ++arow, ++out){
		typename Matrix<T>::iterator brow = b.row_begin();
		for(typename Vector<T>::iterator coef = arow->begin();
		    coef != arow->end(); ++coef, ++brow){
			typename Vector<T>::iterator dst = out->begin();
			for(typename Vector<T>::iterator src = brow->begin();
			    src != brow->end(); ++src, ++dst){
				*dst += (*coef) * (*src);
			}
		}
	}
	return ret;
}

template<class T>
Vector<T> operator*(const Vector<T>& lhs, const Matrix<T>& rhs)
//-------------------------------------------------------------------
// Product of lhs and rhs where lhs is treated as a row vector: the
// result has rhs.columns() elements, element j being the sum over i
// of lhs[i] * rhs[i][j]. Throws Exception if
// lhs.size() != rhs.rows().
//-------------------------------------------------------------------
{
	if(lhs.size() != rhs.rows()){
		throw Exception("operator*(Vector<T>, Matrix<T>)", "Incompatible sizes");
	}
	// Element access here goes through Vector<T>::iterator, which needs
	// non-const operands, so iterate over mutable copies.
	Vector<T> coef(lhs);
	Matrix<T> b(rhs);
	Vector<T> ret(rhs.columns());
	std::fill(ret.begin(), ret.end(), T(0));

	// Accumulate coef[i] * (row i of b) into ret for every row i.
	typename Matrix<T>::iterator brow = b.row_begin();
	for(typename Vector<T>::iterator c = coef.begin(); c != coef.end(); ++c, ++brow){
		typename Vector<T>::iterator dst = ret.begin();
		for(typename Vector<T>::iterator src = brow->begin();
		    src != brow->end(); ++src, ++dst){
			*dst += (*c) * (*src);
		}
	}
	return ret;
}

template<class T>
Vector<T> operator*(const Matrix<T>& lhs, const Vector<T>& rhs)
//-------------------------------------------------------------------
// Product of lhs and rhs where rhs is treated as a column vector: the
// result has lhs.rows() elements, element i combining row i of lhs
// with rhs. Throws Exception if lhs.columns() != rhs.size().
//-------------------------------------------------------------------
{
	if(lhs.columns() != rhs.size()){
		throw Exception("operator*(Matrix<T>, Vector<T>)", "Incompatible sizes");
	}
	Vector<T> ret(lhs.rows());
	typename Vector<T>::iterator out = ret.begin();
	// lhs is const, so its rows must be walked with const_iterator.
	for(typename Matrix<T>::const_iterator row = lhs.row_begin();
	    row != lhs.row_end(); ++row, ++out){
		// NOTE(review): relies on vector.hpp's Vector<T> * Vector<T>
		// yielding the dot product as a T, as the original code did.
		*out = rhs * (*row);
	}
	return ret;
}

NUM_END

#endif

//===================================================================
// Revision History
//
// Version 1.0 - 04/10/1999 - New.
//===================================================================
