/*
 * <<< mpi_local.c >>>
 *
 * --- OSIRIS implementation of MPI subset 'mpi_local.c'
 *     Copyright (C) 2000 Amano Lab., Keio University. ---
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License along
 *  with this program; if not, write to the Free Software Foundation, Inc.,
 *  59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
 */

/* #define __MPI_LOCAL_DEBUG */

#include <sys/types.h>
#include <stdlib.h>
#include <string.h>
#include "mpi.h"
#include "mpi_local.h"
#include "mpni_io_funcs.h"
#ifdef __MPI_LOCAL_DEBUG
# include <stdio.h>
#endif /* __MPI_LOCAL_DEBUG */

static void* mpi_local_new_buffer(void);
static void mpi_local_delete_buffer(void*);
#ifdef __MPI_LOCAL_DEBUG
static void mpi_local_show_buffer(const char*, void*);
#endif /* __MPI_LOCAL_DEBUG */

/* Number of buckets in the receive hash table (recvhash).
   NOTE(review): not static, so this has external linkage; confirm whether
   any other translation unit relies on it before narrowing it. */
const size_t hashsize = 4093;
/* Nonzero once __mpi_local_init() has completed successfully. */
static int mpi_local_init_flag = 0;
/* max_packet_size: value of __mpni_io_get_max_packet_size();
   max_data_size: payload bytes per packet, i.e. max_packet_size minus
   sizeof(__MPI_packet_header). */
static size_t max_packet_size, max_data_size;
/* Single staging buffer for outgoing packets (packet header + payload). */
static void* sendbuf;
/* hashsize buckets of received, not yet consumed packets; each entry is a
   singly linked chain through __MPI_buffer_header.next, bucketed by
   ((source << 2) ^ tag) % hashsize. */
static void** recvhash;
/* Singly linked list (through __MPI_buffer_header.next) of receive
   buffers available for reuse by mpi_local_new_buffer(). */
static void* freelist;

/*
 * Obtain a receive buffer: an __MPI_buffer_header followed by room for one
 * packet.  A buffer is taken from the head of freelist when one is
 * available; otherwise a new one is malloc'ed.
 *
 * Returns NULL only if malloc fails.
 * NOTE(review): callers currently do not check for NULL.
 */
static void* mpi_local_new_buffer(void)
{
	void* p;
	if (freelist != NULL) {
		p = freelist;
		freelist = ((__MPI_buffer_header*)(p))->next;
		return p;
	}
	/* __mpi_local_send() transmits each packet padded up to a multiple of
	   16 bytes, so reserve max_packet_size rounded up the same way. */
	return malloc(sizeof(__MPI_buffer_header)
				  + ((max_packet_size + 15) & ~(size_t)(0xf)));
}

static void mpi_local_delete_buffer(void* p)
{
	((__MPI_buffer_header*)(p))->next = freelist;
	freelist = p;
}

#ifdef __MPI_LOCAL_DEBUG
/*
 * Debug helper: print the buffer header (source PU, size) and packet header
 * (index/total, tag) of the received packet stored in p to stderr, prefixed
 * with str and this PU's id.
 */
static void mpi_local_show_buffer(const char* str, void* p)
{
	__MPI_buffer_header* bufhead = (__MPI_buffer_header*)(p);
	/* step over the buffer header via char* (arithmetic on void* is a
	   non-standard extension) */
	__MPI_packet_header* pkthead
		= (__MPI_packet_header*)((char*)(p) + sizeof(__MPI_buffer_header));
	/* explicit casts so every argument matches its conversion specifier */
	fprintf(stderr,
			"%s(pu%d): received packet is "
			"(src:%lu,size:%lu,idx:%lu/%lu,tag:%d)\n",
			str, (int)(get_puid()),
			(unsigned long)(bufhead->source), (unsigned long)(bufhead->size),
			(unsigned long)(pkthead->index), (unsigned long)(pkthead->total),
			(int)(pkthead->tag));
	fflush(stderr);
}
#endif /* __MPI_LOCAL_DEBUG */

/*
 * Initialize the local MPI layer.  Calling it again after success is a
 * no-op.
 *
 * Initializes the network interface (__mpni_io_init), derives the payload
 * size per packet from the interface's maximum packet size, allocates the
 * send staging buffer and an empty receive hash table, and empties the
 * freelist.
 *
 * Returns 1 on success (or when already initialized), 0 on failure.  On
 * failure no buffers are left allocated and the layer stays uninitialized.
 */
int __mpi_local_init(void)
{
	size_t i;
	if (mpi_local_init_flag) return 1;
	if (!__mpni_io_init()) return 0;
	max_packet_size = __mpni_io_get_max_packet_size();
	/* A packet must hold its header plus at least one data byte; otherwise
	   max_data_size would underflow or be 0 (division by zero in
	   __mpi_local_send / __mpi_local_receive). */
	if (max_packet_size <= sizeof(__MPI_packet_header)) return 0;
	max_data_size = max_packet_size - sizeof(__MPI_packet_header);
	freelist = NULL;
	/* __mpi_local_send() hands __mpni_io_send() the packet length rounded
	   up to a multiple of 16 bytes, so size the staging buffer to match. */
	sendbuf = malloc((max_packet_size + 15) & ~(size_t)(0xf));
	recvhash = malloc(hashsize * sizeof(*recvhash));
	if (sendbuf == NULL || recvhash == NULL) {
		free(sendbuf);
		free(recvhash);
		sendbuf = NULL;
		recvhash = NULL;
		return 0;
	}
	for (i = 0; i < hashsize; i++) recvhash[i] = NULL;
#ifdef __MPI_LOCAL_DEBUG
	fprintf(stderr,
			"__mpi_local_init(pu%d): max_packet_size:%lu max_data_size:%lu\n",
			(int)(get_puid()), (unsigned long)(max_packet_size),
			(unsigned long)(max_data_size));
	fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
	mpi_local_init_flag = 1;
	return 1;
}

/*
 * Release everything allocated by __mpi_local_init() and by receiving:
 * the send buffer, every packet still chained in the receive hash, the
 * hash table itself, and all buffers on the freelist.
 *
 * Afterwards the layer is marked uninitialized and all pointers are reset,
 * so a later __mpi_local_init() allocates fresh buffers instead of
 * returning early with dangling ones.  No-op when not initialized.
 */
void __mpi_local_finalize(void)
{
	void* p;
	size_t i;	/* size_t to match hashsize */
	if (!mpi_local_init_flag) return;
	/* free send buffer(s) */
	free(sendbuf);
	sendbuf = NULL;
	/* free receive buffer(s) still waiting in the hash chains */
	for (i = 0; i < hashsize; i++) {
		p = recvhash[i];
		while (p != NULL) {
			void* next = ((__MPI_buffer_header*)(p))->next;
			free(p);
			p = next;
		}
	}
	free(recvhash);
	recvhash = NULL;
	/* free freelist */
	p = freelist;
	while (p != NULL) {
		void* next = ((__MPI_buffer_header*)(p))->next;
		free(p);
		p = next;
	}
	freelist = NULL;
	mpi_local_init_flag = 0;
}

/*
 * Rank of the calling process: this PU's id as reported by get_puid().
 */
int __mpi_local_comm_rank(void)
{
	const int rank = get_puid();
	return rank;
}

/*
 * Number of participating processes: the PU count reported by get_punum().
 */
int __mpi_local_comm_size(void)
{
	const int size = get_punum();
	return size;
}

/*
 * Send count elements of type dtype from buf to PU dst, tagged with tag.
 *
 * The payload is split into packets of at most max_data_size bytes; only
 * the last one may be shorter.  Each packet is staged in sendbuf as an
 * __MPI_packet_header (total packet count, index, payload size, tag)
 * followed by its payload.  The transmitted length is rounded up to a
 * multiple of 16 bytes.  Before each packet this busy-waits until
 * __mpni_io_sense_send_queue() reports room, then calls __mpni_io_send().
 *
 * An unsupported dtype or a non-positive count sends no packets.
 * Always returns MPI_SUCCESS.
 */
__MPI_return_code __mpi_local_send
	(void* buf, int count, MPI_Datatype dtype, int dst, int tag)
{
	size_t sizeof_dtype, total_data_size, total_packet_number, i;
	switch (dtype) {
	case MPI_BYTE:
	case MPI_CHAR:
	case MPI_UNSIGNED_CHAR:
		sizeof_dtype = 1; break;
	case MPI_SHORT:
	case MPI_UNSIGNED_SHORT:
		sizeof_dtype = sizeof(short); break;
	case MPI_INT:
	case MPI_UNSIGNED:
		sizeof_dtype = sizeof(int); break;
	case MPI_LONG:
	case MPI_UNSIGNED_LONG:
		sizeof_dtype = sizeof(long); break;
	case MPI_FLOAT:
		sizeof_dtype = sizeof(float); break;
	case MPI_DOUBLE:
		sizeof_dtype = sizeof(double); break;
	default:
		/* unsupported datatype: use 0 so nothing is sent, rather than
		   reading an uninitialized element size */
		sizeof_dtype = 0; break;
	}
	/* a negative count would wrap to a huge size_t; treat it as empty */
	total_data_size = (count > 0) ? (size_t)(count) * sizeof_dtype : 0;
	total_packet_number = (total_data_size + max_data_size - 1) / max_data_size;
#ifdef __MPI_LOCAL_DEBUG
	fprintf(stderr,
			"__mpi_local_send(pu%d): total_data_size is %lu bytes, "
			"total packet number is %lu\n",
			(int)(get_puid()), (unsigned long)(total_data_size),
			(unsigned long)(total_packet_number));
#endif /* __MPI_LOCAL_DEBUG */
	for (i = 0; i < total_packet_number; i++) {
		__MPI_packet_header* pkthead = (__MPI_packet_header*)(sendbuf);
		size_t data_size, packet_size;
		if (i < total_packet_number - 1) {
			data_size = max_data_size;
		} else {
			/* last packet carries the remainder (a full packet if 0) */
			data_size = total_data_size % max_data_size;
			if (data_size == 0) data_size = max_data_size;
		}
		pkthead->total = (size_t)(total_packet_number);
		pkthead->index = i;
		pkthead->size = data_size;
		pkthead->tag = tag;
		/* char* arithmetic: arithmetic on void* is a GNU extension */
		memcpy((char*)(sendbuf) + sizeof(__MPI_packet_header),
			   (char*)(buf) + i * max_data_size, data_size);
		/* transmit header + payload rounded up to a multiple of 16 bytes */
		packet_size = (sizeof(__MPI_packet_header) + data_size + 15)
					  & ~(size_t)(0xf);
		while (__mpni_io_sense_send_queue() == 0) {
			size_t j;
			for (j = 0; j < 16; j++);	/* short spin before polling again */
		}
		__mpni_io_send((size_t)(dst), sendbuf, packet_size);
	}
#ifdef __MPI_LOCAL_DEBUG
	fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
	return MPI_SUCCESS;
}

__MPI_return_code __mpi_local_receive
	(void* buf, int count, MPI_Datatype dtype, int src, int tag,
	 MPI_Status *status)
{
	/* fetch all buffered packets in network interface */
#ifdef __MPI_LOCAL_DEBUG
	fprintf(stderr,
			"__mpi_local_receive(pu%d): "
			"fetch all buffered packets in network interface...\n",
			get_puid());
#endif /* __MPI_LOCAL_DEBUG */
	while (1) {
		size_t src_tmp, size_tmp, key;
		size_t buffered;
		void* p;
		__MPI_buffer_header* bufhead;
		__MPI_packet_header* pkthead;
		buffered = __mpni_io_sense_receive_queue(&src_tmp, &size_tmp);
		if (buffered == 0) break;
		p = mpi_local_new_buffer();
		bufhead = (__MPI_buffer_header*)(p);
		pkthead = (__MPI_packet_header*)(p + sizeof(__MPI_buffer_header));
		bufhead->source = src_tmp;
		bufhead->size = size_tmp;
		__mpni_io_receive(pkthead);
#ifdef __MPI_LOCAL_DEBUG
		mpi_local_show_buffer("__mpi_local_receive", p);
#endif /* __MPI_LOCAL_DEBUG */
		key = (size_t)((src_tmp << 2) ^ pkthead->tag) % hashsize;
		bufhead->next = recvhash[key];
		recvhash[key] = p;
	}
	{
		size_t sizeof_dtype, total_data_size, total_packet_number, remainder;
		size_t target_key;
		void* curr;
		void* prev;
		switch (dtype) {
		case MPI_BYTE:
		case MPI_CHAR:
		case MPI_UNSIGNED_CHAR:
			sizeof_dtype = 1; break;
		case MPI_SHORT:
		case MPI_UNSIGNED_SHORT:
			sizeof_dtype = sizeof(short); break;
		case MPI_INT:
		case MPI_UNSIGNED:
			sizeof_dtype = sizeof(int); break;
		case MPI_LONG:
		case MPI_UNSIGNED_LONG:
			sizeof_dtype = sizeof(long); break;
		case MPI_FLOAT:
			sizeof_dtype = sizeof(float); break;
		case MPI_DOUBLE:
			sizeof_dtype = sizeof(double); break;
		}
		total_data_size = count * sizeof_dtype;
		total_packet_number
			= (total_data_size + max_data_size - 1) / max_data_size;
		remainder = total_packet_number;
		target_key = (size_t)((src << 2) ^ tag) % hashsize;
		curr = recvhash[target_key];
#ifdef __MPI_LOCAL_DEBUG
		fprintf(stderr,
				"__mpi_local_receive(pu%d): total_data_size is %d bytes, "
				"total packet number is %d\n",
				get_puid(), total_data_size, total_packet_number);
#endif /* __MPI_LOCAL_DEBUG */
		while (remainder > 0) {
			__MPI_buffer_header* bufhead;
			__MPI_packet_header* pkthead;
			if (curr == NULL) {
				while (1) {
					size_t src_tmp, size_tmp, key;
#ifdef __MPI_LOCAL_DEBUG
				fprintf(stderr,
						"__mpi_local_receive(pu%d): search failed, "
						"wait for packet(s)...\n", get_puid());
#endif /* __MPI_LOCAL_DEBUG */
					while (__mpni_io_sense_receive_queue(&src_tmp, &size_tmp)
						   == 0) {
						size_t i;
						for (i = 0; i < 16; i++);
					}
					curr = mpi_local_new_buffer();
					bufhead = (__MPI_buffer_header*)(curr);
					pkthead = (__MPI_packet_header*)(curr
						+ sizeof(__MPI_buffer_header));
					bufhead->source = src_tmp;
					bufhead->size = size_tmp;
					__mpni_io_receive(pkthead);
#ifdef __MPI_LOCAL_DEBUG
					mpi_local_show_buffer("__mpi_local_receive", curr);
#endif /* __MPI_LOCAL_DEBUG */
					key = (size_t)((src_tmp << 2) ^ pkthead->tag) % hashsize;
					bufhead->next = recvhash[key];
					recvhash[key] = curr;
					if (bufhead->source == src &&
						pkthead->tag == tag &&
						pkthead->total == total_packet_number) break;
				}
			}
			bufhead = (__MPI_buffer_header*)(curr);
			pkthead = (__MPI_packet_header*)(curr
				+ sizeof(__MPI_buffer_header));
			if (bufhead->source != src ||
				pkthead->tag != tag ||
				pkthead->total != total_packet_number) {
				prev = curr;
				curr = bufhead->next;
				continue;
			}
#ifdef __MPI_LOCAL_DEBUG
			fprintf(stderr,
					"__mpi_local_receive(pu%d): packet found, "
					"(src:%d,size:%d,idx:%d/%d,tag:%d)\n",
					get_puid(), bufhead->source, bufhead->size,
					pkthead->index, pkthead->total, pkthead->tag);
#endif /* __MPI_LOCAL_DEBUG */
			memcpy(buf + pkthead->index * max_data_size,
				   curr + sizeof(__MPI_buffer_header)
						+ sizeof(__MPI_packet_header), pkthead->size);
			if (curr == recvhash[target_key]) {
				void* next = bufhead->next;
				recvhash[target_key] = next;
				bufhead->next = freelist;
				freelist = curr;
				curr = next;
			} else {
				void* next = bufhead->next;
				((__MPI_buffer_header*)(prev))->next = next;
				bufhead->next = freelist;
				freelist = curr;
				curr = next;
			}
			remainder--;
		}
	}
	if (status != NULL) {
		/* sorry, not implemented */
	}
#ifdef __MPI_LOCAL_DEBUG
	fflush(stderr);
#endif /* __MPI_LOCAL_DEBUG */
	return MPI_SUCCESS;
}
