#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "mpi.h"
#include "jdmath.h"
#include "oxexport.h"
#include "../../oxmpi_codes.h"

/* NOTE(review): presumably a verbosity bit flag; it is not referenced in this
   section of the file -- confirm whether it is used elsewhere. */
#define OXMPI_VERBOSE  0x001

/* TRUE after FnMPI_Init has initialized MPI (or assumed it initialized under
   SKIP_MPI_Init); cleared again by FnMPI_Finalize. */
static int s_bMPI_Initialized = FALSE;
static int s_iMPI_IsMaster = -1;		/* -1: not set, 0: slave, 1: master */
/* Communicator used by every MPI call in this file. */
static MPI_Comm s_iMPI_comm = MPI_COMM_WORLD;

/* Translates an OxMPI_* reduction-operation code into the corresponding
   MPI_Op handle. Codes without an MPI counterpart yield MPI_OP_NULL. */
static MPI_Op Ox2MPI_Operation(int opArg)
{
	MPI_Op op = MPI_OP_NULL;

	if      (opArg == OxMPI_MAX)     op = MPI_MAX;
	else if (opArg == OxMPI_MIN)     op = MPI_MIN;
	else if (opArg == OxMPI_SUM)     op = MPI_SUM;
	else if (opArg == OxMPI_PROD)    op = MPI_PROD;
	else if (opArg == OxMPI_LAND)    op = MPI_LAND;
	else if (opArg == OxMPI_BAND)    op = MPI_BAND;
	else if (opArg == OxMPI_LOR)     op = MPI_LOR;
	else if (opArg == OxMPI_BOR)     op = MPI_BOR;
	else if (opArg == OxMPI_LXOR)    op = MPI_LXOR;
	else if (opArg == OxMPI_BXOR)    op = MPI_BXOR;
	else if (opArg == OxMPI_MINLOC)  op = MPI_MINLOC;
	else if (opArg == OxMPI_MAXLOC)  op = MPI_MAXLOC;
	else if (opArg == OxMPI_REPLACE) op = MPI_REPLACE;

	return op;
}
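/* Currently unused; kept for reference. It returns long because these MPI
   handles may be pointer types in some MPI implementations.
   NOTE: before enabling it, add a default return (the switch has none) and
   use a pointer-sized integer type: long cannot hold a pointer on LLP64
   platforms such as 64-bit Windows.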
static long Ox2MPI_Other(int iArg)
{
	switch (iArg)
	{
		case OxMPI_COMM_WORLD 	: return (long)MPI_COMM_WORLD;
		case OxMPI_COMM_SELF  	: return (long)MPI_COMM_SELF;
		case OxMPI_GROUP_EMPTY	: return (long)MPI_GROUP_EMPTY;
	};
}
*/
/* Maps an MPI return/error code onto the OxMPI_* code set used on the Ox
   side. Anything not explicitly listed becomes OxMPI_ERR_OTHER. */
static int MPI2Ox_Error(int iArg)
{
	int oxcode = OxMPI_ERR_OTHER;

	if      (iArg == MPI_SUCCESS)   oxcode = OxMPI_SUCCESS;
	else if (iArg == MPI_ERR_TYPE)  oxcode = OxMPI_ERR_TYPE;
	else if (iArg == MPI_ERR_COMM)  oxcode = OxMPI_ERR_COMM;
	else if (iArg == MPI_ERR_RANK)  oxcode = OxMPI_ERR_RANK;
	else if (iArg == MPI_ERR_ROOT)  oxcode = OxMPI_ERR_ROOT;
	else if (iArg == MPI_ERR_GROUP) oxcode = OxMPI_ERR_GROUP;
	else if (iArg == MPI_ERR_OP)    oxcode = OxMPI_ERR_OP;

	return oxcode;
}

/* Converts the OxMPI_* source sentinels to their MPI values. Any other
   value is treated as a real rank and passed through unchanged. */
static int Ox2MPI_Source(int iArg)
{
	if (iArg == OxMPI_ANY_SOURCE)
		return (int)MPI_ANY_SOURCE;
	if (iArg == OxMPI_PROC_NULL)
		return (int)MPI_PROC_NULL;
	if (iArg == OxMPI_UNDEFINED)
		return (int)MPI_UNDEFINED;
	return iArg;
}
/* Inverse of Ox2MPI_Source: MPI source sentinels become OxMPI_* codes,
   and real ranks are returned as-is. */
static int MPI2Ox_Source(int iArg)
{
	if (iArg == MPI_ANY_SOURCE)
		return OxMPI_ANY_SOURCE;
	if (iArg == MPI_PROC_NULL)
		return OxMPI_PROC_NULL;
	if (iArg == MPI_UNDEFINED)
		return OxMPI_UNDEFINED;
	return iArg;
}

/* Converts the OxMPI_* tag sentinels to MPI values; ordinary tags pass
   through unchanged. */
static int Ox2MPI_Tag(int iArg)
{
	if (iArg == OxMPI_ANY_TAG)
		return (int)MPI_ANY_TAG;
	if (iArg == OxMPI_UNDEFINED)
		return (int)MPI_UNDEFINED;
	return iArg;
}
/* Inverse of Ox2MPI_Tag: MPI tag sentinels become OxMPI_* codes, and
   ordinary tags are returned as-is. */
static int MPI2Ox_Tag(int iArg)
{
	if (iArg == MPI_ANY_TAG)
		return OxMPI_ANY_TAG;
	if (iArg == MPI_UNDEFINED)
		return OxMPI_UNDEFINED;
	return iArg;
}

/* Exit hook registered by FnMPI_Init through OxRunMainExitCall: synchronizes
   all processes and finalizes MPI, provided it is still initialized.
   The flag is cleared afterwards, matching FnMPI_Finalize, so that a second
   invocation cannot call MPI_Finalize twice (which MPI forbids). */
static void OXCALL mpi_exit(void)
{
	if (!s_bMPI_Initialized)
		return;

	MPI_Barrier(s_iMPI_comm);
	MPI_Finalize();
	s_bMPI_Initialized = FALSE;
}
/* Replaces *pv with a string value of c characters whose storage is
   malloc'ed (c + 1 bytes). Only the terminating '\0' is written; the caller
   fills in the characters. c == 0 gives an empty string with a NULL data
   pointer.
   A negative c raises ER_SIZENEGATIVE, and allocation failure raises ER_OM.
   If OxRunError were to return, pv is left as a valid empty string rather
   than a negative size or a NULL pointer that would then be written through. */
static void OXCALL oxlibvalstrmalloc(OxVALUE *pv, int c)
{
	OxFreeByValue(pv);
	pv->t.sval.data = NULL;		/* never leave a dangling pointer behind */

	if (c < 0)
	{
		OxRunError(ER_SIZENEGATIVE, NULL);
		c = 0;
	}
	else if (c > 0)
	{
		/* size_t arithmetic: c + 1 cannot overflow as it could in int */
		char *data = malloc((size_t)c + 1);

		if (data == NULL)
		{
			OxRunError(ER_OM, NULL);
			c = 0;
		}
		else
		{
			data[c] = '\0';
			pv->t.sval.data = data;
		}
	}

	pv->t.sval.size = c;
	pv->type = OX_STRING | OX_VALUE;
}

/* MPI_SetMaster(const imaster [, const iquiet]):
   imaster > 0 marks this process as master, imaster == 0 as slave, and a
   negative value keeps the current setting.
   If iquiet is given and non-zero, and this process is a slave, the Ox print
   level is set to 0. */
void OXCALL FnMPI_SetMaster(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	int iMaster;

	OxLibCheckType(OX_INT, pv, 0, 0);
	iMaster = OxInt(pv,0);
	if (iMaster >= 0)
		s_iMPI_IsMaster = (iMaster > 0) ? 1 : 0;

	if (cArg <= 1)
		return;

	OxLibCheckType(OX_INT, pv, 1, 1);
	if (s_iMPI_IsMaster == 0 && OxInt(pv,1))
		OxSetPrintlevel(0);
}
/* MPI_IsMaster(): 1 only if this process was explicitly marked as master,
   otherwise 0 (this includes the "not set" state). */
void OXCALL FnMPI_IsMaster(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	int bMaster = (s_iMPI_IsMaster == 1) ? 1 : 0;

	OxInt(rtn,0) = bMaster;
}

/* MPI_Init(): initializes MPI once and registers mpi_exit to finalize it
   when the Ox main function ends. rtn (may be NULL for internal callers)
   receives an OxMPI_* code; OxMPI_SUCCESS if MPI was already initialized.
   Build options:
     SKIP_MPI_Init  - MPI is assumed initialized by the host main program;
     OXMPI_NO_ARGV  - MPI_Init is given no program arguments (MPICH2).
   The SKIP_MPI_Init branch is a separate #else, so no unreachable code or
   unused variable is compiled in that configuration. */
void OXCALL FnMPI_Init(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
#ifndef SKIP_MPI_Init
	int iret;
#endif

	if (rtn)											 /* rtn may be NULL */
		OxInt(rtn,0) = OxMPI_SUCCESS;

	if (s_bMPI_Initialized)				 /* if already done: don't do again */
		return;

#ifdef SKIP_MPI_Init
	/* assume already initialized in main */
	s_bMPI_Initialized = TRUE;
	OxRunMainExitCall(mpi_exit);	 /* call MPI_Finalize at end of Ox main */
#else
# ifdef OXMPI_NO_ARGV
	iret = MPI_Init(NULL, NULL);	   /* MPICH2: don't pass Ox args to MPI */
# else
	{	int argc;  char **argv;

		OxGetMainArgs(&argc, &argv);  /* points to arguments for Ox program */
		iret = MPI_Init(&argc, &argv);	     /* MPI may have prepended args */
		OxSetMainArgs(argc, argv);			  /* and then remove them again */
	}
# endif

	iret = MPI2Ox_Error(iret);	/* translate error code to those used by Ox */
	if (rtn)
		OxInt(rtn,0) = iret;						 /* set the return code */

	if (iret == OxMPI_SUCCESS)
	{
		s_bMPI_Initialized = TRUE;
		OxRunMainExitCall(mpi_exit); /* call MPI_Finalize at end of Ox main */
	}
#endif
}
/* MPI_Finalize(): shuts MPI down if it is currently initialized. Clearing
   the flag also keeps the registered mpi_exit hook from finalizing a second
   time. rtn is not modified. */
void OXCALL FnMPI_Finalize(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	if (s_bMPI_Initialized)
	{
		MPI_Finalize();
		s_bMPI_Initialized = FALSE;
	}
}

/* MPI_Comm_size(): number of processes in s_iMPI_comm. MPI is started
   automatically on first use; the status code of that start is discarded. */
void OXCALL FnMPI_Comm_size(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	int nprocs;

	if (!s_bMPI_Initialized)
		FnMPI_Init(NULL, NULL, 0);

	MPI_Comm_size(s_iMPI_comm, &nprocs);
	OxSetInt(rtn, 0, nprocs);
}
/* MPI_Comm_rank(): rank of this process in s_iMPI_comm. MPI is started
   automatically on first use; the status code of that start is discarded. */
void OXCALL FnMPI_Comm_rank(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	int rank;

	if (!s_bMPI_Initialized)
		FnMPI_Init(NULL, NULL, 0);

	MPI_Comm_rank(s_iMPI_comm, &rank);
	OxSetInt(rtn, 0, rank);
}
/* MPI_Get_processor_name(): returns the MPI processor name as an Ox string.
   MPI is started on first use, because this call is not allowed before
   MPI_Init (this matches Comm_size/Comm_rank).
   If the call fails or reports an out-of-range length, an empty string is
   returned instead of an uninitialized buffer. The name is always
   terminated explicitly at the reported length. */
void OXCALL FnMPI_Get_processor_name(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	int  namelen = 0;
	char processor_name[MPI_MAX_PROCESSOR_NAME + 1];

	if (!s_bMPI_Initialized)
		FnMPI_Init(NULL, NULL, 0);

	if (MPI_Get_processor_name(processor_name, &namelen) != MPI_SUCCESS
	    || namelen < 0 || namelen > MPI_MAX_PROCESSOR_NAME)
		namelen = 0;
	processor_name[namelen] = '\0';

	OxValSetString(rtn, processor_name);
}
/* MPI_Wtime(): wall-clock time in seconds as a double. MPI is started on
   first use, because MPI_Wtime may not be called before MPI_Init (this
   matches Comm_size/Comm_rank). */
void OXCALL FnMPI_Wtime(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	double wtime;

	if (!s_bMPI_Initialized)
		FnMPI_Init(NULL, NULL, 0);

	wtime = MPI_Wtime();
	OxSetDbl(rtn, 0, wtime);
}

/* MPI_Bcast(const aval, const iroot): broadcasts the value referenced by aval
   from process iroot to all processes in s_iMPI_comm.
   Protocol: the root first broadcasts an int header {type, dim1, dim2} so
   that the other processes can size their destination; the payload follows.
   Arrays are broadcast element by element through recursion.
   rtn receives an OxMPI_* status code. For arrays this is the first failing
   element's code, or OxMPI_SUCCESS.
   Empty matrices and strings skip the payload broadcast on every process
   (the header is shared, so the collective calls stay matched) instead of
   dereferencing a NULL data pointer. */
void OXCALL FnMPI_Bcast(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	int  i, myid, root, iret, oxret, isroot;
	int  header[3] = {0, 0, 0};		/* unused dimensions are sent as 0 */
	int  ival = 0;
	double d = 0.0;
	OxVALUE *apv0, apvarg2[2];

	OxLibCheckType(OX_ARRAY, pv, 0, 0);
	OxLibCheckType(OX_INT, pv, 1, 1);
	apv0 = OxArray(pv,0);
	root = OxInt(pv,1);

	MPI_Comm_rank(s_iMPI_comm, &myid);
	isroot = (myid == root);

	if (isroot)
	{
		header[0] = GETPVTYPE(apv0);
		switch (header[0])
		{
			case OX_INT:
			case OX_DOUBLE:
				break;
			case OX_MATRIX:
				header[1] = OxMatr(apv0, 0);
				header[2] = OxMatc(apv0, 0);
				break;
			case OX_STRING:
				header[1] = OxStrLen(apv0, 0);
				break;
			case OX_ARRAY:
				header[1] = OxArrayLen(apv0, 0);
				break;
			default:
			{
				char s[200];
				sprintf(s, "MPI_Bcast(): cannot broadcast type %.100s", SOxGetTypeName(header[0]));
				OxRunErrorMessage(s);
			}
		}
	}
	/* collective: the root sends the header, every other process receives it */
	iret = MPI_Bcast(header, 3, MPI_INT, root, s_iMPI_comm);
	OxInt(rtn,0) = MPI2Ox_Error(iret);
	if (iret != MPI_SUCCESS)
		return;

	switch (header[0])
	{
		case OX_INT:
			if (isroot)
				ival = OxInt(apv0, 0);
			iret = MPI_Bcast(&ival, 1, MPI_INT, root, s_iMPI_comm);
			if (!isroot)
			{	OxFreeByValue(apv0);	/* release any previous payload */
				OxSetInt(apv0, 0, ival);
			}
			break;
		case OX_DOUBLE:
			if (isroot)
				d = OxDbl(apv0, 0);
			iret = MPI_Bcast(&d, 1, MPI_DOUBLE, root, s_iMPI_comm);
			if (!isroot)
			{	OxFreeByValue(apv0);	/* release any previous payload */
				OxSetDbl(apv0, 0, d);
			}
			break;
		case OX_MATRIX:
			if (!isroot)
				OxValSetMatSize(apv0, header[1], header[2]);
			if (header[1] > 0 && header[2] > 0)
				iret = MPI_Bcast(OxMat(apv0,0)[0], header[1] * header[2], MPI_DOUBLE, root, s_iMPI_comm);
			break;
		case OX_STRING:
			if (!isroot)
				oxlibvalstrmalloc(apv0, header[1]);
			if (header[1] > 0)
				iret = MPI_Bcast(OxStr(apv0,0), header[1], MPI_CHAR, root, s_iMPI_comm);
			break;
		case OX_ARRAY:
			if (!isroot)
				OxValSetArray(apv0, header[1]);

			/* recursive calls already store Ox codes in rtn: keep the first
			   failure rather than re-translating it as an MPI code */
			oxret = OxMPI_SUCCESS;
			apvarg2[1] = pv[1];
			for (i = 0; i < header[1]; ++i)
			{
				OxSetAddress(apvarg2, 0, OxArrayData(apv0) + i);
				FnMPI_Bcast(rtn, apvarg2, 2);
				if (oxret == OxMPI_SUCCESS)
					oxret = OxInt(rtn,0);
			}
			OxInt(rtn,0) = oxret;
			return;
	}
	OxInt(rtn,0) = MPI2Ox_Error(iret);
}

/* MPI_Reduce(const val, const iop, const iroot): combines val across all
   processes with the OxMPI_* operation iop. The result goes to rtn on process
   iroot. On the other processes rtn has the right type/size, but its
   contents are not significant.
   Supported types are int, double and matrix. Other types raise an Ox
   run-time error.
   If MPI_Reduce reports a failure, rtn is replaced by the OxMPI_* error code
   as an int.
   Fixed: the int case used to fall through into a second, double-typed
   reduction, and the result was unconditionally overwritten (corrupted for
   doubles and matrices) by the error code. */
void OXCALL FnMPI_Reduce(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	int iret = MPI_SUCCESS, root;
	MPI_Op op;

	OxLibCheckType(OX_INT, pv, 1, 2);
	op = Ox2MPI_Operation(OxInt(pv,1));
	root = OxInt(pv, 2);

	switch (GETPVTYPE(pv))
	{
		case OX_INT:
			OxSetInt(rtn, 0, 0);
			iret = MPI_Reduce((void*)&(OxInt(pv,0)), (void*)&(OxInt(rtn,0)), 1, MPI_INT, op, root, s_iMPI_comm);
			break;
		case OX_DOUBLE:
			OxSetDbl(rtn, 0, 0);
			iret = MPI_Reduce((void*)&(OxDbl(pv,0)), (void*)&(OxDbl(rtn,0)), 1, MPI_DOUBLE, op, root, s_iMPI_comm);
			break;
		case OX_MATRIX:
		{
			int r = OxMatr(pv,0), c = OxMatc(pv,0);
			OxValSetMatSize(rtn, r, c);
			/* an empty matrix has no data pointer to pass */
			if (r > 0 && c > 0)
				iret = MPI_Reduce((void*)OxMat(pv,0)[0], (void*)OxMat(rtn,0)[0], r * c, MPI_DOUBLE, op, root, s_iMPI_comm);
			break;
		}
		default:
		{
			char s[200];
			sprintf(s, "MPI_Reduce(): cannot reduce type %.100s", SOxGetTypeName(GETPVTYPE(pv)));
			OxRunErrorMessage(s);
			return;
		}
	}

	if (iret != MPI_SUCCESS)
	{
		OxFreeByValue(rtn);		/* drop a possibly allocated result matrix */
		OxSetInt(rtn, 0, MPI2Ox_Error(iret));
	}
}

/* MPI_Iprobe(const isource, const itag): non-blocking check for a matching
   message. If one is pending, returns the 1 x 3 matrix <source, tag, error>
   in OxMPI_* codes; otherwise (or if the probe fails) returns an empty
   matrix.
   MPI only sets the MPI_ERROR field of a status in multiple-completion
   calls, so it is preset to MPI_SUCCESS here instead of reading an
   uninitialized field. */
void OXCALL FnMPI_Iprobe(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	int flag = 0, iret;
	MPI_Status status;

	OxLibCheckType(OX_INT, pv, 0, 1);

	status.MPI_ERROR = MPI_SUCCESS;
	iret = MPI_Iprobe(Ox2MPI_Source(OxInt(pv,0)), Ox2MPI_Tag(OxInt(pv,1)), s_iMPI_comm, &flag, &status);

	if (iret == MPI_SUCCESS && flag)
	{
		OxValSetMatSize(rtn, 1, 3);
		OxMat(rtn,0)[0][0] = MPI2Ox_Source(status.MPI_SOURCE);
		OxMat(rtn,0)[0][1] = MPI2Ox_Tag(status.MPI_TAG);
		OxMat(rtn,0)[0][2] = MPI2Ox_Error(status.MPI_ERROR);
	}
	else
		OxValSetMatSize(rtn, 0, 0);
}
/* MPI_Probe(const isource, const itag): blocks until a matching message is
   available, then returns the 1 x 3 matrix <source, tag, error> in OxMPI_*
   codes. If MPI_Probe fails, an empty matrix is returned.
   MPI only sets the MPI_ERROR field of a status in multiple-completion
   calls, so it is preset to MPI_SUCCESS here instead of reading an
   uninitialized field. */
void OXCALL FnMPI_Probe(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	int iret;
	MPI_Status status;

	OxLibCheckType(OX_INT, pv, 0, 1);

	status.MPI_ERROR = MPI_SUCCESS;
	iret = MPI_Probe(Ox2MPI_Source(OxInt(pv,0)), Ox2MPI_Tag(OxInt(pv,1)), s_iMPI_comm, &status);

	if (iret == MPI_SUCCESS)
	{
		OxValSetMatSize(rtn, 1, 3);
		OxMat(rtn,0)[0][0] = MPI2Ox_Source(status.MPI_SOURCE);
		OxMat(rtn,0)[0][1] = MPI2Ox_Tag(status.MPI_TAG);
		OxMat(rtn,0)[0][2] = MPI2Ox_Error(status.MPI_ERROR);
	}
	else
		OxValSetMatSize(rtn, 0, 0);
}
/* MPI_Send(const val, const idest, const itag): sends val to process idest.
   Protocol: a 3-int header {type, int value or dim1, dim2} is sent first.
   For doubles, non-empty matrices and non-empty strings the payload follows
   as a second message with the same tag. Arrays send their elements one by
   one through recursion.
   OxMPI_* sentinels in idest/itag are translated to MPI values, as in
   MPI_Probe. Unsupported types raise an Ox run-time error; previously they
   were silently skipped, leaving the receiver waiting or out of step.
   rtn is not set here (it is only passed on to the recursive calls). */
void OXCALL FnMPI_Send(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	int dest, tag, len = 0;
	int aisend[3] = {0, 0, 0};	/* unused slots are sent as 0, not garbage */

	OxLibCheckType(OX_INT, pv, 1, 2);
	dest = Ox2MPI_Source(OxInt(pv,1));
	tag = Ox2MPI_Tag(OxInt(pv,2));

	aisend[0] = GETPVTYPE(pv);
	switch (aisend[0])
	{
		case OX_INT:
			aisend[1] = OxInt(pv,0);
			MPI_Send(aisend, 3, MPI_INT, dest, tag, s_iMPI_comm);
			return;		/* the value travels inside the header */
		case OX_DOUBLE:
			len = 1;
			break;
		case OX_MATRIX:
			aisend[1] = OxMatr(pv,0);
			aisend[2] = OxMatc(pv,0);
			len = aisend[1] * aisend[2];
			break;
		case OX_STRING:
			len = aisend[1] = OxStrLen(pv,0);
			break;
		case OX_ARRAY:
			len = aisend[1] = OxArraySize(pv);
			break;
		default:
		{
			char s[200];
			sprintf(s, "MPI_Send(): cannot send type %.100s", SOxGetTypeName(aisend[0]));
			OxRunErrorMessage(s);
			return;
		}
	}
	MPI_Send(aisend, 3, MPI_INT, dest, tag, s_iMPI_comm);

	if (len <= 0)
		return;		/* empty matrix/string/array: header says it all */

	switch (aisend[0])
	{
		case OX_DOUBLE:
			MPI_Send(&(OxDbl(pv,0)), 1, MPI_DOUBLE, dest, tag, s_iMPI_comm);
			break;
		case OX_MATRIX:
			MPI_Send(OxMat(pv,0)[0], len, MPI_DOUBLE, dest, tag, s_iMPI_comm);
			break;
		case OX_STRING:
			MPI_Send(OxStr(pv,0), len, MPI_CHAR, dest, tag, s_iMPI_comm);
			break;
		case OX_ARRAY:
		{
			int i;  OxVALUE pvarg[3];
			pvarg[1] = pv[1];				/* dest (Ox value, translated per call) */
			pvarg[2] = pv[2];				/* tag */
			for (i = 0; i < len; ++i)
			{
				pvarg[0] = OxArrayData(pv)[i];		/* array entry */
				FnMPI_Send(rtn, pvarg, 3);
			}
			break;
		}
	}
}
void OXCALL FnMPI_Recv(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	int source, tag, airecv[3], len, i;
	MPI_Status status;
	
	OxLibCheckType(OX_INT, pv, 0, 1);
	source = OxInt(pv,0);
	tag = OxInt(pv,1);

	MPI_Recv(airecv, 3, MPI_INT, source, tag, s_iMPI_comm, &status);

	len = airecv[1];
	switch (airecv[0])
	{
		case OX_INT:
			OxSetInt(rtn, 0, airecv[1]);
			return;		/* finished */
		case OX_DOUBLE:
			MPI_Recv(&OxDbl(rtn,0), 1, MPI_DOUBLE, source, tag, s_iMPI_comm, &status);
			rtn->type = OX_DOUBLE;
			break;
		case OX_MATRIX:
			OxValSetMatSize(rtn, airecv[1], airecv[2]);
			len = OxMatrc(rtn,0);
			if (len)
				MPI_Recv(OxMat(rtn,0)[0], len, MPI_DOUBLE, source, tag, s_iMPI_comm, &status);
			break;
		case OX_STRING:
			oxlibvalstrmalloc(rtn, len);
			if (len)
			{	MPI_Recv(rtn->t.sval.data, len, MPI_CHAR, source, tag, s_iMPI_comm, &status);
				rtn->t.sval.data[len] = '\0';
			}
			break;
		case OX_ARRAY:
			OxValSetArray(rtn, len);
			for (i = 0; i < len; ++i)
				FnMPI_Recv(OxArrayData(rtn) + i, pv, cArg);
			break;
	}
}
/* MPI_Barrier(): waits until every process in s_iMPI_comm reaches this
   call. The MPI status code is not reported back to Ox. */
void OXCALL FnMPI_Barrier(OxVALUE *rtn, OxVALUE *pv, int cArg)
{
	(void)MPI_Barrier(s_iMPI_comm);
}
