/*
 * <<< mp_bitonic.c >>>
 *
 * --- Sample application for isis 'bitonic sort' - for MPI
 *     Copyright (C) 2000-2001 Amano Lab., Keio University. ---
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License along
 *  with this program; if not, write to the Free Software Foundation, Inc.,
 *  59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
 */

#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <mpi.h>

/*
 * Tunables and the key type.
 *
 *   VALUEBIT                 number of significant bits in a sort key
 *   VALUEMASK                mask selecting the low VALUEBIT bits
 *   RADIXBIT                 width of one radix-sort digit (half a key)
 *   DEFAULT_SIZE             total element count when none is given
 *   DEFAULT_RADIX_THRESHOLD  default run length handed to the radix sort
 *   MAX_PRINT_SIZE           largest problem size that verbose mode dumps
 */
#define VALUEBIT				16
/* Built from an unsigned constant: left-shifting the negative int ~0 is
 * undefined behaviour. */
#define VALUEMASK				(~(~0U << VALUEBIT))
#define RADIXBIT				(VALUEBIT >> 1)
#define DEFAULT_SIZE			64
#define DEFAULT_RADIX_THRESHOLD	16
#define MAX_PRINT_SIZE			256

typedef unsigned data_t;

/*
 * Linear congruential generator.  rnd() advances `seed` and yields the
 * state shifted right by 16 bits, converted to unsigned; srnd(x) reseeds.
 * Because rnd() assigns to `seed`, two calls must never appear in one
 * expression without a sequence point between them.
 */
#define rnd() ((unsigned)((seed = 1566083941UL * seed + 1) >> 16))
#define srnd(x) (seed = (x))
static unsigned long seed = 1;

/*
 * Forward declarations.  The helpers are declared `static`; bitonic_sort()
 * is the only one declared without it.
 */
/* highest power of two contained in the argument (0 for 0) */
static int fix_number_to_exp2n(int);
/* fill an array with pseudo-random keys drawn from rnd() */
static void init(data_t*, int);
/* print keys in hex: show() wraps every 16 keys, show_raw() stays on one line */
static void show(data_t*, int);
static void show_raw(data_t*, int);
/* 1 when the array is in non-decreasing order, otherwise 0 */
static int check(data_t*, int);
/* two-pass counting sort (array, scratch, count): ascending / descending */
static void radix_sort(data_t*, data_t*, int);
static void radix_rev_sort(data_t*, data_t*, int);
/* element-wise compare-exchange of two equally long blocks */
static void bitonic_split(data_t*, data_t*, int);
static void bitonic_rev_split(data_t*, data_t*, int);
/* turn a bitonic run (array, scratch, count) into a monotonic one */
static void bitonic_to_sequence(data_t*, data_t*, int);
static void bitonic_to_rev_sequence(data_t*, data_t*, int);
/* local sort followed by the inter-process merge steps; called from main() */
void bitonic_sort(data_t*, data_t*, int, int, int, int, int);

/*
 * fix_number_to_exp2n -- round x down to a power of two.
 *
 * Returns the largest power of two that does not exceed x, or 0 when
 * x <= 0.  Negative input is rejected up front: right-shifting a negative
 * int is implementation-defined and, with an arithmetic shift, never
 * reaches zero, so the bit-counting loop below would not terminate.
 */
int fix_number_to_exp2n(int x)
{
	int bits = 0;

	if (x <= 0) return 0;
	/* count the significant bits of x */
	while (x != 0) {
		x >>= 1;
		bits++;
	}
	/* the top set bit alone is the answer */
	return 1 << (bits - 1);
}

/*
 * init -- fill a[0..size-1] with pseudo-random keys.
 *
 * Each key combines two consecutive rnd() draws and keeps VALUEBIT bits
 * of the result.  Nothing is written when size <= 0.
 */
void init(data_t *a, int size)
{
	int	i;
	for (i = 0; i < size; i++) {
		/*
		 * rnd() assigns to the generator state, so the two draws are kept
		 * in separate statements; combining them in one expression would
		 * modify `seed` twice without an intervening sequence point.
		 */
		const unsigned long hi = rnd();
		const unsigned long lo = rnd();
		const unsigned long rnd_num = (hi << 16) | lo;
		a[i] = (data_t)((rnd_num >> (32 - VALUEBIT)) & VALUEMASK);
	}
}

/*
 * show -- print a[0..size-1] as zero-padded hex, 16 keys per line.
 *
 * Keys on the same line are separated by '.'; the output always ends with
 * a newline.  With size <= 0 only that newline is printed (the loop is
 * bounded by size, so an empty array is never dereferenced).
 */
void show(data_t *a, int size)
{
	int i;
	for (i = 0; i < size; i++) {
		/* separator goes in front of every key except the first;
		 * a new line starts after each group of 16 */
		if (i > 0) putchar((i % 16 != 0) ? '.' : '\n');
		printf("%0*x", ((VALUEBIT + 3) / 4), a[i]);
	}
	putchar('\n');
}

/*
 * show_raw -- print a[0..size-1] as zero-padded hex on a single line.
 *
 * Keys are separated by '.'; no trailing newline is written.  Nothing is
 * printed for size <= 0 (the loop is bounded by size, so an empty array
 * is never dereferenced).
 */
void show_raw(data_t *a, int size)
{
	int i;
	for (i = 0; i < size; i++) {
		if (i > 0) putchar('.');
		printf("%0*x", ((VALUEBIT + 3) / 4), a[i]);
	}
}

/*
 * check -- verify that a[0..size-1] is in non-decreasing order.
 *
 * Returns 1 when no key is smaller than its predecessor (which includes
 * arrays of fewer than two keys), 0 otherwise.
 */
int check(data_t *a, int size)
{
	int idx = 1;

	while (idx < size) {
		if (a[idx - 1] > a[idx]) {
			return 0;	/* descent found */
		}
		idx++;
	}
	return 1;
}

/*
 * radix_sort -- sort a[0..n-1] into ascending order.
 *
 * Two stable counting-sort passes of RADIXBIT bits each: the low digit
 * moves the keys from a into the scratch array w, the high digit moves
 * them back into a.  The high digit is used unmasked as a table index,
 * so keys are expected to fit in 2 * RADIXBIT bits.  The counter table
 * is static and shared between calls.
 */
void radix_sort(data_t *a, data_t *w, int n)
{
	static int c[1 << RADIXBIT];
	const int radix = (1 << RADIXBIT), radixmask = radix - 1;
	int digit, pos;
#ifdef DEBUG
	printf("%02d/%02d radix_sort(n:%d):\n", get_puid(), get_punum(), n);
	printf("%02d/%02d < ", get_puid(), get_punum());
	show_raw(a, n); printf("\n");
#endif /* DEBUG */

	/* pass 1: low digit, a -> w */
	digit = 0;
	while (digit < radix) c[digit++] = 0;
	for (pos = 0; pos < n; pos++) c[a[pos] & radixmask] += 1;
	/* running totals: c[d] becomes one past the last slot of digit d */
	for (digit = 1; digit < radix; digit++) c[digit] += c[digit - 1];
	/* walk backwards so equal digits keep their relative order */
	pos = n;
	while (pos-- > 0) {
		const data_t key = a[pos];
		w[--c[key & radixmask]] = key;
	}

	/* pass 2: high digit, w -> a */
	digit = 0;
	while (digit < radix) c[digit++] = 0;
	for (pos = 0; pos < n; pos++) c[w[pos] >> RADIXBIT] += 1;
	for (digit = 1; digit < radix; digit++) c[digit] += c[digit - 1];
	pos = n;
	while (pos-- > 0) {
		const data_t key = w[pos];
		a[--c[key >> RADIXBIT]] = key;
	}
#ifdef DEBUG
	printf("%02d/%02d > ", get_puid(), get_punum());
	show_raw(a, n); printf("\n");
#endif /* DEBUG */
}

/*
 * radix_rev_sort -- sort a[0..n-1] into descending order.
 *
 * Same two-pass counting sort as the ascending variant, except that the
 * counters are accumulated from the largest digit downwards, which places
 * larger digits at lower indices.  w is scratch space for n keys; the
 * high digit is used unmasked as a table index, so keys are expected to
 * fit in 2 * RADIXBIT bits.  The counter table is static.
 */
void radix_rev_sort(data_t *a, data_t *w, int n)
{
	static int c[1 << RADIXBIT];
	const int radix = (1 << RADIXBIT), radixmask = radix - 1;
	int digit, pos;
#ifdef DEBUG
	printf("%02d/%02d radix_rev_sort(n:%d):\n", get_puid(), get_punum(), n);
	printf("%02d/%02d < ", get_puid(), get_punum());
	show_raw(a, n); printf("\n");
#endif /* DEBUG */

	/* pass 1: low digit, a -> w */
	digit = 0;
	while (digit < radix) c[digit++] = 0;
	for (pos = 0; pos < n; pos++) c[a[pos] & radixmask] += 1;
	/* suffix totals: c[d] counts the keys whose digit is >= d */
	for (digit = radix - 2; digit >= 0; digit--) c[digit] += c[digit + 1];
	pos = n;
	while (pos-- > 0) {
		const data_t key = a[pos];
		w[--c[key & radixmask]] = key;
	}

	/* pass 2: high digit, w -> a */
	digit = 0;
	while (digit < radix) c[digit++] = 0;
	for (pos = 0; pos < n; pos++) c[w[pos] >> RADIXBIT] += 1;
	for (digit = radix - 2; digit >= 0; digit--) c[digit] += c[digit + 1];
	pos = n;
	while (pos-- > 0) {
		const data_t key = w[pos];
		a[--c[key >> RADIXBIT]] = key;
	}
#ifdef DEBUG
	printf("%02d/%02d > ", get_puid(), get_punum());
	show_raw(a, n); printf("\n");
#endif /* DEBUG */
}

/*
 * bitonic_split -- element-wise compare-exchange of two n-key blocks.
 *
 * Afterwards a[i] <= b[i] for every i in 0..n-1: the smaller key of each
 * pair ends up in a, the larger in b.
 */
void bitonic_split(data_t *a, data_t *b, int n)
{
	data_t *lo = a, *hi = b;
	int remaining;
#ifdef DEBUG
	printf("%02d/%02d bitonic_split(n:%d):\n", get_puid(), get_punum(), n);
	printf("%02d/%02d < ", get_puid(), get_punum());
	show_raw(a, n); printf(" : "); show_raw(b, n); printf("\n");
#endif /* DEBUG */
	for (remaining = n; remaining > 0; remaining--, lo++, hi++) {
		const data_t x = *lo, y = *hi;
		if (x > y) {
			/* out of order: exchange the pair */
			*lo = y;
			*hi = x;
		}
	}
#ifdef DEBUG
	printf("%02d/%02d > ", get_puid(), get_punum());
	show_raw(a, n); printf(" : "); show_raw(b, n); printf("\n");
#endif /* DEBUG */
}

/*
 * bitonic_rev_split -- reversed element-wise compare-exchange.
 *
 * Afterwards a[i] >= b[i] for every i in 0..n-1: the larger key of each
 * pair ends up in a, the smaller in b.
 */
void bitonic_rev_split(data_t *a, data_t *b, int n)
{
	data_t *hi = a, *lo = b;
	int remaining;
#ifdef DEBUG
	printf("%02d/%02d bitonic_rev_split(n:%d):\n", get_puid(), get_punum(), n);
	printf("%02d/%02d < ", get_puid(), get_punum());
	show_raw(a, n); printf(" : "); show_raw(b, n); printf("\n");
#endif /* DEBUG */
	for (remaining = n; remaining > 0; remaining--, hi++, lo++) {
		const data_t x = *hi, y = *lo;
		if (x < y) {
			/* out of order: exchange the pair */
			*hi = y;
			*lo = x;
		}
	}
#ifdef DEBUG
	printf("%02d/%02d > ", get_puid(), get_punum());
	show_raw(a, n); printf(" : "); show_raw(b, n); printf("\n");
#endif /* DEBUG */
}

/*
 * bitonic_to_sequence -- rearrange a bitonic run into ascending order.
 *
 *   a  n keys forming a bitonic sequence; sorted in place
 *   w  scratch space for n keys
 *   n  number of keys; fewer than two is a no-op
 *
 * The run is rotated into w so that it starts at its turning point, then
 * merged back into a by consuming w from both ends.  Input that is not
 * bitonic is not sorted correctly.
 *
 * NOTE(review): the DEBUG traces call get_puid()/get_punum(), which are
 * not declared in the visible part of this file.
 */
void bitonic_to_sequence(data_t *a, data_t *w, int n)
{
	int i, j, k, inc_flag = 0;

	/* nothing to order; also keeps the scans below inside the array */
	if (n < 2) return;
#ifdef DEBUG
	printf("%02d/%02d bitonic_to_sequence(n:%d):\n", get_puid(), get_punum(),
		   n);
	printf("%02d/%02d < ", get_puid(), get_punum());
	show_raw(a, n); printf("\n");
#endif /* DEBUG */
	/*
	 * Skip the leading plateau: the first adjacent pair that differs gives
	 * the initial direction.  The scan must include the last pair
	 * (n-2, n-1), otherwise a run such as 5,5,5,3 is mistaken for a
	 * constant one and returned unsorted.
	 */
	for (i = 0; i < n - 1; i++) {
		if (a[i] < a[i + 1]) {
			inc_flag = 1;
			break;
		} else if (a[i] > a[i + 1]) {
			inc_flag = 0;
			break;
		}
	}
	if (i == n - 1) return;		/* all keys equal */
	if (inc_flag) {
		/* first direction is increase: advance to the peak */
		while (a[i] <= a[i + 1]) {
			i++;
			if (i == n - 1) return;	/* rises to the end: already sorted */
		}
		/* rotate so that w starts at the peak (falls, then rises) */
		j = 0; while (i < n) { w[j] = a[i]; i++, j++; }
		i = 0; while (j < n) { w[j] = a[i]; i++, j++; }
		/* both ends of w are large: take the larger end, fill a from the back */
		i = 0, j = n - 1;
		for (k = n - 1; k >= 0; k--) {
			if (w[i] >= w[j]) {
				a[k] = w[i]; i++;
			} else {
				a[k] = w[j]; j--;
			}
		}
	} else {
		/* first direction is decrease: advance to the valley (or the end) */
		while (a[i] >= a[i + 1]) {
			i++;
			if (i == n - 1) break;
		}
		/* rotate so that w starts at the valley (rises, then falls) */
		j = 0; while (i < n) { w[j] = a[i]; i++, j++; }
		i = 0; while (j < n) { w[j] = a[i]; i++, j++; }
		/* both ends of w are small: take the smaller end, fill a from the front */
		i = 0, j = n - 1;
		for (k = 0; k < n; k++) {
			if (w[i] <= w[j]) {
				a[k] = w[i]; i++;
			} else {
				a[k] = w[j]; j--;
			}
		}
	}
#ifdef DEBUG
	printf("%02d/%02d > ", get_puid(), get_punum());
	show_raw(a, n); printf("\n");
#endif /* DEBUG */
}

/*
 * bitonic_to_rev_sequence -- rearrange a bitonic run into descending order.
 *
 *   a  n keys forming a bitonic sequence; sorted in place
 *   w  scratch space for n keys
 *   n  number of keys; fewer than two is a no-op
 *
 * The run is rotated into w so that it starts at its turning point, then
 * merged back into a by consuming w from both ends.  Input that is not
 * bitonic is not sorted correctly.
 *
 * NOTE(review): the DEBUG traces call get_puid()/get_punum(), which are
 * not declared in the visible part of this file.
 */
void bitonic_to_rev_sequence(data_t *a, data_t *w, int n)
{
	int i, j, k, inc_flag = 0;

	/* nothing to order; also keeps the scans below inside the array */
	if (n < 2) return;
#ifdef DEBUG
	printf("%02d/%02d bitonic_to_rev_sequence(n:%d):\n",
		   get_puid(), get_punum(), n);
	printf("%02d/%02d < ", get_puid(), get_punum());
	show_raw(a, n); printf("\n");
#endif /* DEBUG */
	/*
	 * Skip the leading plateau: the first adjacent pair that differs gives
	 * the initial direction.  The scan must include the last pair
	 * (n-2, n-1), otherwise a run such as 5,5,5,7 is mistaken for a
	 * constant one and returned in ascending order.
	 */
	for (i = 0; i < n - 1; i++) {
		if (a[i] < a[i + 1]) {
			inc_flag = 1;
			break;
		} else if (a[i] > a[i + 1]) {
			inc_flag = 0;
			break;
		}
	}
	if (i == n - 1) return;		/* all keys equal */
	if (inc_flag) {
		/* first direction is increase: advance to the peak (or the end) */
		while (a[i] <= a[i + 1]) {
			i++;
			if (i == n - 1) break;
		}
		/* rotate so that w starts at the peak (falls, then rises) */
		j = 0; while (i < n) { w[j] = a[i]; i++, j++; }
		i = 0; while (j < n) { w[j] = a[i]; i++, j++; }
		/* both ends of w are large: take the larger end, fill a from the front */
		i = 0, j = n - 1;
		for (k = 0; k < n; k++) {
			if (w[i] >= w[j]) {
				a[k] = w[i]; i++;
			} else {
				a[k] = w[j]; j--;
			}
		}
	} else {
		/* first direction is decrease: advance to the valley */
		while (a[i] >= a[i + 1]) {
			i++;
			if (i == n - 1) return;	/* falls to the end: already sorted */
		}
		/* rotate so that w starts at the valley (rises, then falls) */
		j = 0; while (i < n) { w[j] = a[i]; i++, j++; }
		i = 0; while (j < n) { w[j] = a[i]; i++, j++; }
		/* both ends of w are small: take the smaller end, fill a from the back */
		i = 0, j = n - 1;
		for (k = n - 1; k >= 0; k--) {
			if (w[i] <= w[j]) {
				a[k] = w[i]; i++;
			} else {
				a[k] = w[j]; j--;
			}
		}
	}
#ifdef DEBUG
	printf("%02d/%02d > ", get_puid(), get_punum());
	show_raw(a, n); printf("\n");
#endif /* DEBUG */
}

/*
 * main -- parse options, build the local data, sort it and optionally
 * verify the result.
 *
 * Arguments:
 *   -i        print clock() values around the init and sort phases (rank 0)
 *   -r<num>   run length handed to the radix sort (rounded down to 2^k)
 *   -t        verify the global ordering after sorting
 *   -v        print parameters, the data (small sizes only) and the verdict
 *   <num>     total number of keys (rounded down to 2^k, at least 4)
 *
 * Only the largest power-of-two subset of ranks takes part; the remaining
 * ranks finalize MPI and leave immediately.  The exit status is 1 when
 * verification fails or the configuration is rejected, 0 otherwise.
 */
int main(int argc, char **argv)
{
	clock_t init_start_time = 0, init_end_time = 0;
	clock_t calc_start_time = 0, calc_end_time = 0;
	data_t *a;
	data_t *w;
	/* plain int throughout: MPI_Comm_size/MPI_Comm_rank store through int*,
	 * and the "< 0" fallbacks below need a signed type to be reachable */
	int size = DEFAULT_SIZE, radix_threshold = DEFAULT_RADIX_THRESHOLD;
	int punum, puid, n, argi;
	int check_flag = 0, info_flag = 0, verbose_flag = 0, result;

	/* initialize MPI environment and random seed */
	MPI_Init(&argc, &argv);
	MPI_Comm_size(MPI_COMM_WORLD, &punum);
	MPI_Comm_rank(MPI_COMM_WORLD, &puid);
	punum = fix_number_to_exp2n(punum);
	if (puid >= punum) {
		/* surplus rank: takes no part in the sort */
		MPI_Finalize();
		return 0;
	}
	srnd(puid + 1); (void)rnd();

	/* read arguments */
	for (argi = 1; argi < argc; argi++) {
		const char *arg = argv[argi];
		if (arg[0] == '-') {
			switch (arg[1]) {
			case 'i':
				info_flag = 1;
				break;
			case 'r':
				radix_threshold = atoi(arg + 2);
				if (radix_threshold < 0) {
					radix_threshold = DEFAULT_RADIX_THRESHOLD;
				}
				break;
			case 't':
				check_flag = 1;
				break;
			case 'v':
				verbose_flag = 1;
				break;
			default:
				break;
			}
		} else if (isdigit((unsigned char)arg[0])) {
			size = atoi(arg);
			if (size < 0) size = DEFAULT_SIZE;
		}
	}
	size = fix_number_to_exp2n(size);
	if (size < 4) size = 4;
	if (punum > size / 4) {
		/* every participating rank sees the same arguments and gets here */
		if (puid == 0) fprintf(stderr, "Too many processors.\n");
		MPI_Finalize();
		return 1;
	}
	n = size / punum;
	radix_threshold = fix_number_to_exp2n(radix_threshold);
	if (radix_threshold > n) radix_threshold = n;
	if (radix_threshold > size / 2) radix_threshold = size / 2;
	if (radix_threshold < 4) radix_threshold = 4;
	if (verbose_flag && puid == 0) {
		printf("size:%d radix-threshold:%d value:0-%#x\n",
			   size, radix_threshold, (unsigned)VALUEMASK);
	}

	/* initialize */
	if (info_flag) init_start_time = clock();
	a = malloc((size_t)n * sizeof *a);
	w = malloc((size_t)n * sizeof *w);
	if (a == NULL || w == NULL) {
		fputs("Out of memory.\n", stderr);
		/* possibly only this rank failed: take the whole job down rather
		 * than leave the partners waiting in a message exchange */
		MPI_Abort(MPI_COMM_WORLD, 1);
		return 1;
	}
#ifndef DEBUG
	init(a, n);
#else
	/* simple init data: for debug only */
	{
		int i;
		for (i = 0; i < n; i++) a[i] = (i + puid * n) & VALUEMASK;
	}
#endif /* DEBUG */
	if (verbose_flag && size <= MAX_PRINT_SIZE) {
		printf("%02d: ", puid);
		show_raw(a, n);
		printf("\n");
	}
	if (info_flag) init_end_time = clock();

	/* calculate */
	if (info_flag) calc_start_time = clock();
	bitonic_sort(a, w, size, n, radix_threshold, punum, puid);
	if (info_flag) calc_end_time = clock();

	/* check */
	if (check_flag) {
		int value[2];
		result = check(a, n);
		if (puid == 0) {
			/* rank 0 chains the [first, last] key of every rank in order */
			MPI_Status dummy;
			int max = (int)a[n - 1], i;
			for (i = 1; i < punum; i++) {
				MPI_Recv(value, 2, MPI_INT, i, i, MPI_COMM_WORLD, &dummy);
				if (max > value[0] || value[0] > value[1]) {
					result = 0;
				}
				max = value[1];
			}
			if (verbose_flag) {
				puts(result ? "success." : "failed.");
			}
		} else {
			if (result) {
				value[0] = (int)a[0];
				value[1] = (int)a[n - 1];
			} else {
				/* first > last tells rank 0 that the local check failed */
				value[0] = 1, value[1] = 0;
			}
			MPI_Send(value, 2, MPI_INT, 0, puid, MPI_COMM_WORLD);
		}
	} else {
		result = 1;
	}

	/* show results */
	if (verbose_flag && size <= MAX_PRINT_SIZE) {
		printf("%02d: ", puid);
		show_raw(a, n);
		printf("\n");
	}
	if (info_flag && puid == 0) {
		printf("init start: %10ld\n"
			   "init end:   %10ld\n"
			   "calc start: %10ld\n"
			   "calc end:   %10ld\n"
			   "init time:  %10ld\n"
			   "calc time:  %10ld\n",
			   (long)init_start_time, (long)init_end_time,
			   (long)calc_start_time, (long)calc_end_time,
			   (long)(init_end_time - init_start_time),
			   (long)(calc_end_time - calc_start_time));
	}

	free(a);
	free(w);
	/* every rank that called MPI_Init must finalize before exiting */
	MPI_Finalize();
	return (result ? 0 : 1);
}

/*
 * bitonic_sort -- sort the local block, then merge across processes.
 *
 *   a                local keys (n of them); holds the result on return
 *   w                scratch space for n keys
 *   size             total number of keys over all processes (unused here)
 *   n                number of local keys
 *   radix_threshold  run length sorted directly by the radix sort
 *   punum            number of participating processes
 *   puid             rank of this process
 *
 * The stride arithmetic assumes n, radix_threshold and punum are powers
 * of two, as main() arranges.
 *
 * Phase 1 (local): runs of radix_threshold keys are radix-sorted in
 * alternating directions and merged with compare-exchange steps until the
 * whole block is sorted -- ascending on even ranks, descending on odd
 * ranks.
 *
 * Phase 2 (global): for every stage i and every bit j = i..0 the process
 * pairs up with rank puid ^ (1 << j).  Each side ships one half of its
 * block, compare-exchanges the half it kept against the half it received,
 * and returns the received half.  Both transfers use MPI_Sendrecv: the two
 * partners transmit simultaneously, and a blocking MPI_Send on both sides
 * completes only if the library happens to buffer the message.
 */
void bitonic_sort(data_t *a, data_t *w, int size, int n, int radix_threshold,
				  int punum, int puid)
{
	int i, j;

	(void)size;
	if (radix_threshold < n) {
		/* sort adjacent runs in opposite directions: each pair is bitonic */
		for (i = 0; i < n; i += (radix_threshold << 1)) {
			radix_sort(a + i, w, radix_threshold);
			radix_rev_sort(a + i + radix_threshold, w, radix_threshold);
		}
		/* grow sorted runs: each group of i keys becomes an ascending half
		 * followed by a descending half */
		for (i = (radix_threshold << 2); i <= n; i <<= 1) {
			const int i_div_2 = (i >> 1);
			int k, l;
			for (j = 0; j < n; j += i) {
				for (k = i_div_2; k > radix_threshold; k >>= 1) {
					const int k_div_2 = (k >> 1);
					for (l = 0; l < i_div_2; l += k) {
						const int k1 = j + l, k2 = i_div_2 + j + l;
						bitonic_split(a + k1, a + k1 + k_div_2, k_div_2);
						bitonic_rev_split(a + k2, a + k2 + k_div_2, k_div_2);
					}
				}
				for (k = 0; k < i_div_2; k += radix_threshold) {
					bitonic_to_sequence(a + j + k, w, radix_threshold);
					bitonic_to_rev_sequence(a + j + k + i_div_2, w,
											radix_threshold);
				}
			}
		}
		/* final local merge: direction depends on the parity of the rank */
		if ((puid & 1) == 0) {
			for (i = n; i > radix_threshold; i >>= 1) {
				const int i_div_2 = (i >> 1);
				for (j = 0; j < n; j += i) {
					bitonic_split(a + j, a + j + i_div_2, i_div_2);
				}
			}
			for (i = 0; i < n; i += radix_threshold) {
				bitonic_to_sequence(a + i, w, radix_threshold);
			}
		} else {
			for (i = n; i > radix_threshold; i >>= 1) {
				const int i_div_2 = (i >> 1);
				for (j = 0; j < n; j += i) {
					bitonic_rev_split(a + j, a + j + i_div_2, i_div_2);
				}
			}
			for (i = 0; i < n; i += radix_threshold) {
				bitonic_to_rev_sequence(a + i, w, radix_threshold);
			}
		}
	} else {
		if ((puid & 1) == 0) {
			radix_sort(a, w, n);
		} else {
			radix_rev_sort(a, w, n);
		}
	}
#ifdef DEBUG
	printf("%02d/%02d: ", puid, punum); show_raw(a, n); printf("\n");
#endif /* DEBUG */
	{
		/* MPI counts are int; half of the local block moves per transfer */
		const int comm_blocksize = (n >> 1);
		int punum_shift = 0, tag = 0;

		/* punum_shift = log2(punum) */
		for (i = punum; i > 1; i >>= 1) punum_shift++;

		for (i = 0; i < punum_shift; i++) {
			MPI_Status dummy;
			/* bit i+1 of the rank selects the merge direction of this stage */
			const int rev_flag = (((puid >> i) & 2) != 0);
			for (j = i; j >= 0; j--) {
				const int partner = (puid ^ (1 << j));
				const int is_lower = (puid < partner);
				/* the lower rank keeps its first half and ships the second;
				 * the upper rank does the opposite */
				data_t *const keep = is_lower ? a : a + comm_blocksize;
				data_t *const give = is_lower ? a + comm_blocksize : a;
#ifdef DEBUG
				printf("%02d/%02d comm with pu%02d:\n", puid, punum, partner);
				printf("%02d/%02d < ", puid, punum);
				show_raw(a, comm_blocksize); printf(" : ");
				show_raw(a + comm_blocksize, comm_blocksize); printf("\n");
#endif /* DEBUG */
				/* data_t is unsigned, hence MPI_UNSIGNED */
				MPI_Sendrecv(give, comm_blocksize, MPI_UNSIGNED, partner, tag,
							 w, comm_blocksize, MPI_UNSIGNED, partner, tag,
							 MPI_COMM_WORLD, &dummy);
#ifdef DEBUG
				printf("%02d/%02d > ", puid, punum);
				show_raw(is_lower ? a : w, comm_blocksize); printf(" : ");
				show_raw(is_lower ? w : a + comm_blocksize, comm_blocksize);
				printf("\n");
#endif /* DEBUG */
				/* the lower rank collects minima in a normal stage and
				 * maxima in a reversed one; the upper rank mirrors that */
				if (is_lower != rev_flag) {
					bitonic_split(keep, w, comm_blocksize);
				} else {
					bitonic_rev_split(keep, w, comm_blocksize);
				}
#ifdef DEBUG
				printf("%02d/%02d comm with pu%02d:\n", puid, punum, partner);
				printf("%02d/%02d < ", puid, punum);
				show_raw(is_lower ? a : w, comm_blocksize); printf(" : ");
				show_raw(is_lower ? w : a + comm_blocksize, comm_blocksize);
				printf("\n");
#endif /* DEBUG */
				/* hand the partner's half back, receive ours in return */
				MPI_Sendrecv(w, comm_blocksize, MPI_UNSIGNED, partner, tag + 1,
							 give, comm_blocksize, MPI_UNSIGNED, partner,
							 tag + 1, MPI_COMM_WORLD, &dummy);
#ifdef DEBUG
				printf("%02d/%02d > ", puid, punum);
				show_raw(a, comm_blocksize); printf(" : ");
				show_raw(a + comm_blocksize, comm_blocksize); printf("\n");
#endif /* DEBUG */
				tag += 2;
			}
			/* the local block is handed to the bitonic merge to be put in
			 * the order this stage requires */
			if (!rev_flag) {
				bitonic_to_sequence(a, w, n);
			} else {
				bitonic_to_rev_sequence(a, w, n);
			}
#ifdef DEBUG
			printf("%02d/%02d: ", puid, punum); show_raw(a, n); printf("\n");
#endif /* DEBUG */
		}
	}
}
