/*
 * <<< mpi_local.c >>>
 *
 * --- OSIRIS implementation of MPI subset 'mpi_local.c'
 *     Copyright (C) 2000 Amano Lab., Keio University. ---
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License along
 *  with this program; if not, write to the Free Software Foundation, Inc.,
 *  59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
 */

/* #define __MPI_LOCAL_DEBUG */

#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include "mpi.h"
#include "mpi_local.h"
#include "osiris.h"
#include "mpni_io_funcs.h"
#ifdef __MPI_LOCAL_DEBUG
# include <stdio.h>
#endif /* __MPI_LOCAL_DEBUG */

/* Forward declarations of the file-local helpers defined below. */
static __MPI_buffer_header* mpi_local_new_buffer(void);
static void mpi_local_delete_buffer(__MPI_buffer_header* p);
static void mpi_local_enqueue_to_sendqueue(void* buf, int size, int dst,
										   int tag);
static void mpi_local_send_from_sendqueue(void);
static void mpi_local_receive_to_recvhash(void);
static int mpi_local_fetch_from_recvhash(void* buf, int src, int tag,
										 int key, int total_packet_number,
										 int remainder, int* return_src,
										 int* return_tag);

/* Number of buckets in recvhash.
 * NOTE(review): not static, so the name has external linkage; confirm no
 * other translation unit relies on that before narrowing it. */
const size_t hashsize = 7;

/* Set to 1 by __mpi_local_init() once initialisation has succeeded. */
static int mpi_local_init_flag = 0;

/* Maximum packet size reported by the network I/O layer, and the payload
 * bytes that remain in such a packet after its __MPI_packet_header. */
static size_t max_packet_size;
static size_t max_data_size;

/* Packets queued for transmission, oldest at the head. */
static __MPI_buffer_header* sendqueue_head;
static __MPI_buffer_header* sendqueue_tail;

/* hashsize bucket lists of received packets not yet copied out. */
static __MPI_buffer_header** recvhash;

/* Buffers released by mpi_local_delete_buffer(), reused by
 * mpi_local_new_buffer(). */
static __MPI_buffer_header* freelist;

static __MPI_buffer_header* mpi_local_new_buffer(void)
{
	__MPI_buffer_header* p;
	if (freelist == NULL) {
		p = (__MPI_buffer_header*)
			malloc(max_packet_size + sizeof(__MPI_buffer_header));
#ifdef __MPI_LOCAL_DEBUG
		if (p == NULL) {
			fprintf(stderr,
					"mpi_local_new_buffer(pu%d): "
					"cannot allocate memory, aborted.\n",
					get_puid());
			fflush(stderr);
			exit(1);
		}
#endif /* __MPI_LOCAL_DEBUG */
	} else {
		p = freelist;
		freelist = ((__MPI_buffer_header*)(p))->next;
	}
	return p;
}

/*
 * Recycle a buffer obtained from mpi_local_new_buffer() by pushing it onto
 * freelist.  The memory is not released here; __mpi_local_finalize() frees
 * whatever is left on freelist.
 */
static void mpi_local_delete_buffer(__MPI_buffer_header* p)
{
	__MPI_buffer_header* const old_head = freelist;
	freelist = p;
	p->next = old_head;
}

/*
 * Split the message buf[0, size) into packets and append them to the tail
 * of the send queue, addressed to dst and labelled with tag.  Each packet
 * carries at most max_data_size payload bytes; only the last one may be
 * shorter.  Each packet records its index and the total packet count so the
 * receiver can reassemble the message.  Nothing is transmitted here; see
 * mpi_local_send_from_sendqueue().
 *
 * NOTE(review): a zero-byte message produces no packets at all.
 */
static void mpi_local_enqueue_to_sendqueue
	(void* buf, int size, int dst, int tag)
{
	size_t total_packet_number, i;
	total_packet_number = (size + max_data_size - 1) / max_data_size;
#ifdef __MPI_LOCAL_DEBUG
	fprintf(stderr,
			"mpi_local_enqueue_to_sendqueue(pu%d): data size is %d bytes, "
			"total packet number is %d\n",
			get_puid(), size, (int)total_packet_number);
	fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
	for (i = 0; i < total_packet_number; i++) {
		__MPI_buffer_header* bufhead;
		__MPI_packet_header* pkthead;
		size_t packet_size, data_size;
#ifdef __MPI_LOCAL_DEBUG
		fprintf(stderr,
				"mpi_local_enqueue_to_sendqueue(pu%d): "
				"enqueue to sendqueue(%d/%d)\n",
				get_puid(), (int)i, (int)total_packet_number);
		fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
		/* Every packet but the last is full.  The last one carries what is
		 * left over: a full payload when size is an exact multiple. */
		if (i < total_packet_number - 1) {
			data_size = max_data_size;
		} else {
			data_size = size % max_data_size;
			if (data_size == 0) data_size = max_data_size;
		}
		/* Transmitted length is header + payload rounded up to a multiple
		 * of 16 bytes.  NOTE(review): presumably an alignment requirement of
		 * the network interface.  If max_packet_size is not a multiple of 16,
		 * this can exceed max_packet_size. */
		packet_size = (sizeof(__MPI_packet_header) + data_size + 15)
					  & ~(size_t)(0xf);
		bufhead = mpi_local_new_buffer();
		/* The packet (header + payload) follows the buffer header.  Byte
		 * offsets use char* because arithmetic on void* is not ISO C. */
		pkthead = (__MPI_packet_header*)
				  ((char*)bufhead + sizeof(__MPI_buffer_header));
		bufhead->next = NULL;
		bufhead->address = dst;
		bufhead->size = packet_size;
		pkthead->total = total_packet_number;
		pkthead->index = i;
		pkthead->size = data_size;
		pkthead->tag = tag;
		memcpy((char*)pkthead + sizeof(__MPI_packet_header),
			   (const char*)buf + i * max_data_size, data_size);
		/* Append at the tail so packets leave in FIFO order. */
		if (sendqueue_tail == NULL) {
			sendqueue_tail = sendqueue_head = bufhead;
		} else {
			sendqueue_tail->next = bufhead;
			sendqueue_tail = bufhead;
		}
	}
}

/*
 * Hand queued packets to the network interface in FIFO order.  Sending
 * stops as soon as __mpni_io_sense_send_queue() returns 0 (the debug trace
 * reports this as the interface being busy).  Sent buffers are recycled to
 * freelist.  Unsent ones stay in the queue for the next call.
 */
static void mpi_local_send_from_sendqueue(void)
{
	__MPI_buffer_header* bufhead = sendqueue_head;
	if (bufhead == NULL) {
#ifdef __MPI_LOCAL_DEBUG
		fprintf(stderr,
				"mpi_local_send_from_sendqueue(pu%d): sendqueue is empty\n",
				get_puid());
		fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
		return;
	}
#ifdef __MPI_LOCAL_DEBUG
	fprintf(stderr,
			"mpi_local_send_from_sendqueue(pu%d): "
			"send all enqueued packets in sendqueue...\n",
			get_puid());
	fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
	while (bufhead != NULL) {
		__MPI_packet_header* pkthead;
		size_t dst, packet_size;
		if (__mpni_io_sense_send_queue() == 0) {
#ifdef __MPI_LOCAL_DEBUG
			fprintf(stderr,
					"mpi_local_send_from_sendqueue(pu%d): "
					"network interface is busy. abort...\n", get_puid());
			fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
			break;
		}
		/* char* keeps the byte offset ISO C (void* arithmetic is not). */
		pkthead = (__MPI_packet_header*)
				  ((char*)bufhead + sizeof(__MPI_buffer_header));
		dst = bufhead->address;
		/* bufhead->size is the full (rounded) packet length, not just
		 * the payload. */
		packet_size = bufhead->size;
#ifdef __MPI_LOCAL_DEBUG
		fprintf(stderr,
				"mpi_local_send_from_sendqueue(pu%d): send packet to network, "
				"(dst:%d,size:%d,idx:%d/%d,tag:%d)\n",
				get_puid(), (int)dst, (int)packet_size,
				(int)pkthead->index, (int)pkthead->total, (int)pkthead->tag);
		fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
		__mpni_io_send(dst, pkthead, packet_size);
		{
			__MPI_buffer_header* next = bufhead->next;
			mpi_local_delete_buffer(bufhead);
			bufhead = next;
		}
	}
	/* Whatever was not sent becomes the new queue head. */
	sendqueue_head = bufhead;
	if (bufhead == NULL) sendqueue_tail = NULL;
}

/*
 * Drain the network interface.  While __mpni_io_sense_receive_queue()
 * reports a pending packet, copy it into a fresh buffer and file it in
 * recvhash bucket ((source << 2) ^ tag) % hashsize.
 *
 * Packets are appended at the bucket tail so each bucket stays in arrival
 * order.  mpi_local_fetch_from_recvhash() scans from the bucket head, so the
 * oldest matching packet is taken first.  The previous head insertion
 * returned a later message with the same source and tag before an earlier
 * one, which violates MPI's non-overtaking rule.
 */
static void mpi_local_receive_to_recvhash(void)
{
#ifdef __MPI_LOCAL_DEBUG
	fprintf(stderr,
			"mpi_local_receive_to_recvhash(pu%d): "
			"fetch all buffered packets in the network interface...\n",
			get_puid());
	fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
	while (1) {
		size_t src_tmp, size_tmp, key;
		size_t buffered;
		__MPI_buffer_header* bufhead;
		__MPI_packet_header* pkthead;
		__MPI_buffer_header** link;
		buffered = __mpni_io_sense_receive_queue(&src_tmp, &size_tmp);
		if (buffered == 0) break;
		bufhead = mpi_local_new_buffer();
		/* char* keeps the byte offset ISO C (void* arithmetic is not). */
		pkthead = (__MPI_packet_header*)
				  ((char*)bufhead + sizeof(__MPI_buffer_header));
		bufhead->next = NULL;
		bufhead->address = src_tmp;
		bufhead->size = size_tmp;
		__mpni_io_receive(pkthead);
#ifdef __MPI_LOCAL_DEBUG
		fprintf(stderr,
				"mpi_local_receive_to_recvhash(pu%d): "
				"received packet is (src:%d,size:%d,idx:%d/%d,tag:%d)\n",
				get_puid(), (int)bufhead->address, (int)bufhead->size,
				(int)pkthead->index, (int)pkthead->total, (int)pkthead->tag);
		fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
		key = (size_t)((src_tmp << 2) ^ pkthead->tag) % hashsize;
		/* Walk to the end of the bucket and append there (FIFO). */
		link = &recvhash[key];
		while (*link != NULL) link = &(*link)->next;
		*link = bufhead;
	}
}

int mpi_local_fetch_from_recvhash
	(void* buf, int src, int tag, int key, int total_packet_number,
	 int remainder, int *return_src, int *return_tag)
{
	int fetched_count = 0;
#ifdef __MPI_LOCAL_DEBUG
	fprintf(stderr,
			"mpi_local_fetch_from_recvhash(pu%d): "
			"(source:%d tag:%d key:%d), total is %d, remainder is %d\n",
			get_puid(), src, tag, key, total_packet_number, remainder);
	fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
	while (fetched_count < remainder) {
		__MPI_buffer_header* curr = recvhash[key];
		__MPI_buffer_header* prev = NULL;
		if (curr == NULL) {
#ifdef __MPI_LOCAL_DEBUG
			fprintf(stderr,
					"mpi_local_fetch_from_recvhash(pu%d): search failed\n",
					get_puid());
			fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
			break;
		}
		while (curr != NULL) {
			__MPI_buffer_header* bufhead;
			__MPI_packet_header* pkthead;
			bufhead = curr;
			pkthead = (__MPI_packet_header*)
					  ((void*)curr + sizeof(__MPI_buffer_header));
			if ((return_src != NULL || bufhead->address == src) ||
				(return_tag != NULL || pkthead->tag == tag) ||
				pkthead->total == total_packet_number) {
#ifdef __MPI_LOCAL_DEBUG
				fprintf(stderr,
						"mpi_local_fetch_from_recvhash(pu%d): packet found, "
						"(src:%d,size:%d,idx:%d/%d,tag:%d)\n",
						get_puid(), bufhead->address, bufhead->size,
						pkthead->index, pkthead->total, pkthead->tag);
				fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
				memcpy(buf + pkthead->index * max_data_size,
					   (void*)pkthead + sizeof(__MPI_packet_header),
					   pkthead->size);
				if (return_src != NULL) {
					*return_src = bufhead->address;
					return_src = NULL;
				}
				if (return_tag != NULL) {
					*return_tag = pkthead->tag;
					return_tag = NULL;
				}
				{
					__MPI_buffer_header* next = bufhead->next;
					if (curr == recvhash[key]) {
						recvhash[key] = next;
					} else {
						prev->next = next;
					}
					mpi_local_delete_buffer(curr);
					curr = next;
				}
				fetched_count++;
				if (fetched_count == remainder) break;
			} else {
				prev = curr;
				curr = bufhead->next;
			}
		}
	}
	return fetched_count;
}

/*
 * Size in bytes of a single element of the predefined datatype t.
 * A datatype not listed here yields 0.
 */
size_t __mpi_local_get_sizeof_datatype(MPI_Datatype t)
{
	switch (t) {
	case MPI_BYTE:
	case MPI_CHAR:
	case MPI_UNSIGNED_CHAR:
		return 1;
	case MPI_SHORT:
	case MPI_UNSIGNED_SHORT:
		return sizeof(short);
	case MPI_INT:
	case MPI_UNSIGNED:
		return sizeof(int);
	case MPI_LONG:
	case MPI_UNSIGNED_LONG:
		return sizeof(long);
	case MPI_FLOAT:
		return sizeof(float);
	case MPI_DOUBLE:
		return sizeof(double);
	default:
		return 0;
	}
}

/*
 * Initialise this layer.  It brings up the network I/O layer, derives the
 * packet and payload sizes, and starts with an empty send queue, free list
 * and receive hash.  Once initialisation has succeeded, further calls
 * return 1 immediately.
 *
 * Returns 1 on success.  Returns 0 on failure: the I/O layer did not
 * initialise, the reported maximum packet size cannot hold a packet header
 * plus payload, or the receive hash could not be allocated.
 *
 * NOTE(review): on the last two failure paths __mpni_io_init() has already
 * succeeded and is not undone.  No matching teardown call is visible in
 * this file.
 */
int __mpi_local_init(void)
{
	size_t i;
	if (mpi_local_init_flag) return 1;
	if (!__mpni_io_init()) return 0;
	max_packet_size = __mpni_io_get_max_packet_size();
	/* max_data_size is later used as a divisor when splitting messages into
	 * packets, so it must be positive.  This check also keeps the unsigned
	 * subtraction below from wrapping around. */
	if (max_packet_size <= sizeof(__MPI_packet_header)) return 0;
	max_data_size = max_packet_size - sizeof(__MPI_packet_header);
	freelist = NULL;
	sendqueue_head = sendqueue_tail = NULL;
	recvhash = malloc(hashsize * sizeof(*recvhash));
	if (recvhash == NULL) return 0;
	for (i = 0; i < hashsize; i++) recvhash[i] = NULL;
#ifdef __MPI_LOCAL_DEBUG
	fprintf(stderr,
			"__mpi_local_init(pu%d): max_packet_size:%d max_data_size:%d\n",
			get_puid(), (int)max_packet_size, (int)max_data_size);
	fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
	mpi_local_init_flag = 1;
	return 1;
}

void __mpi_local_finalize(void)
{
	__MPI_buffer_header* p;
	int i;
	if (!mpi_local_init_flag) return;
#ifdef __MPI_LOCAL_DEBUG
	fprintf(stderr, "__mpi_local_finalize(pu%d): free all buffers...\n",
			get_puid());
	fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
	/* free send buffer(s) */
	p = sendqueue_head;
	while (p != NULL) {
		__MPI_buffer_header* next = p->next;
		free(p);
		p = next;
	}
	/* free receive buffer(s) */
	for (i = 0; i < hashsize; i++) {
		p = recvhash[i];
		while (p != NULL) {
			__MPI_buffer_header* next = p->next;
			free(p);
			p = next;
		}
	}
	free(recvhash);
	/* free freelist */
	p = freelist;
	while (p != NULL) {
		__MPI_buffer_header* next = p->next;
		free(p);
		p = next;
	}
#ifdef __MPI_LOCAL_DEBUG
	fprintf(stderr, "__mpi_local_finalize(pu%d): done.\n", get_puid());
	fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
}

/* Rank of the calling process, taken directly from get_puid(). */
int __mpi_local_comm_rank(void)
{
	const int rank = get_puid();
	return rank;
}

/* Number of processes, taken directly from get_punum(). */
int __mpi_local_comm_size(void)
{
	const int nprocs = get_punum();
	return nprocs;
}

/*
 * Point-to-point send.  Packetise `size` bytes of buf for destination dst
 * with the given tag, then hand the network interface as many queued
 * packets as it will accept right now.  Packets it cannot take yet stay in
 * the send queue and are pushed out by later calls, for example from
 * __mpi_local_receive().  The result is always MPI_SUCCESS.
 */
__MPI_return_code __mpi_local_send(void* buf, int size, int dst, int tag)
{
#ifdef __MPI_LOCAL_DEBUG
	fprintf(stderr, "__mpi_local_send(pu%d): (size:%d destination:%d tag:%d)\n",
			get_puid(), size, dst, tag);
	fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
	mpi_local_enqueue_to_sendqueue(buf, size, dst, tag);	/* packetise */
	mpi_local_send_from_sendqueue();						/* opportunistic flush */
	return MPI_SUCCESS;
}

/*
 * Blocking receive of `size` bytes into buf from source src with tag tag.
 * src may be MPI_ANY_SOURCE and tag may be MPI_ANY_TAG.
 *
 * The call loops until every packet of the message has been copied into
 * buf.  Each pass flushes the send queue, drains the network interface into
 * recvhash, and fetches matching packets from one bucket.
 *   - While a wildcard is unresolved, the buckets are probed round-robin.
 *   - Once the first packet is found, the resolved source and tag replace
 *     the wildcards for all later fetches.
 * Before returning, it keeps calling mpi_local_send_from_sendqueue() until
 * the send queue is empty.
 *
 * status may be NULL.  Otherwise:
 *   - MPI_SOURCE and MPI_TAG receive the source and tag that matched;
 *   - MPI_ERROR is set to MPI_SUCCESS;
 *   - sizeof_message is set to the requested size, not a count of bytes
 *     actually received.
 *
 * Always returns MPI_SUCCESS.
 */
__MPI_return_code __mpi_local_receive
	(void* buf, int size, int src, int tag, MPI_Status *status)
{
	size_t total_packet_number, remainder;
	size_t target_key;
	int any_src_flag = (src == MPI_ANY_SOURCE),
		any_tag_flag = (tag == MPI_ANY_TAG);
	if (status != NULL) {
		if (!any_src_flag) status->MPI_SOURCE = src;
		if (!any_tag_flag) status->MPI_TAG = tag;
	}
	total_packet_number = (size + max_data_size - 1) / max_data_size;
	/* With source and tag both given, every packet of the message sits in
	 * one known bucket.  With a wildcard, start probing at bucket 0. */
	if (!any_src_flag && !any_tag_flag) {
		target_key = ((size_t)((src << 2) ^ tag) % hashsize);
	} else {
		target_key = 0;
	}
#ifdef __MPI_LOCAL_DEBUG
	fprintf(stderr,
			"__mpi_local_receive(pu%d): "
			"(size:%d source:%d tag:%d), total packet number is %d\n",
			get_puid(), size, src, tag, (int)total_packet_number);
	fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
	remainder = total_packet_number;
	while (remainder > 0) {
		int return_src = 0, return_tag = 0, tmp;
		/* send all buffered packets in sendqueue if available */
		mpi_local_send_from_sendqueue();
		/* receive all buffered packets in the network interface */
		mpi_local_receive_to_recvhash();
		/* fetch matching packets from the current bucket */
		tmp = mpi_local_fetch_from_recvhash
				(buf, src, tag, target_key, total_packet_number,
				 remainder,
				 (any_src_flag ? &return_src : NULL),
				 (any_tag_flag ? &return_tag : NULL));
		if (tmp > 0) {
			remainder -= tmp;
			/* The first packet found resolves the wildcards.  Pass the
			 * resolved values from now on, so later fetches only accept
			 * packets with this source and tag.  target_key already names
			 * the bucket those packets hash to. */
			if (any_src_flag) {
				src = return_src;
				if (status != NULL) status->MPI_SOURCE = return_src;
				any_src_flag = 0;
			}
			if (any_tag_flag) {
				tag = return_tag;
				if (status != NULL) status->MPI_TAG = return_tag;
				any_tag_flag = 0;
			}
		} else if (any_src_flag || any_tag_flag) {
			/* Nothing matched yet: probe the next bucket round-robin. */
			target_key = (target_key + 1) % hashsize;
		}
	}
	/* Flush everything still queued for sending before returning. */
	while (sendqueue_head != NULL) {
		mpi_local_send_from_sendqueue();
	}
	if (status != NULL) {
		status->MPI_ERROR = (int)MPI_SUCCESS;
		status->sizeof_message = size;
	}
	return MPI_SUCCESS;
}
