/*
 * <<< mp_fft.c >>>
 *
 * --- Sample application for isis 'fast fourier transform' - for MPI
 *     Copyright (C) 2000 Amano Lab., Keio University. ---
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License along
 *  with this program; if not, write to the Free Software Foundation, Inc.,
 *  59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
 */

#include <ctype.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <mpi.h>

/* Circle constant used when building the rotation-factor table. */
#define PI 3.14159265358979323846

/* Transform length used by main() when no size is given on the command line. */
#define DEFAULT_SIZE	64
/* main() prints individual elements in verbose mode only up to this size. */
#define MAX_PRINT_SIZE	256
/* Threshold that check() compares the accumulated difference against. */
#define DIFF_LIMIT		1e-7

/* Scalar type used for all sample values. */
typedef double data_t;
/*
 * Complex number as a (real, imaginary) pair.
 * NOTE(review): the name would collide with the `complex` macro if
 * <complex.h> were ever included in this file.
 */
typedef struct { data_t re, im; } complex;

/*
 * Forward declarations.  The static helpers are private to this file;
 * fft() and ufft() have external linkage.
 */
static int fix_number_to_exp2n(int);
static int get_bitrev_number(int, int);
static void make_data(complex*, int, const complex*, int, int);
static void copy_data(const complex*, complex*, int);
static int check(const complex*, const complex*, int);
static void make_rotate_table(complex*, int);
void fft(complex*, complex*, int, const complex*, int, int);
void ufft(complex*, complex*, int, const complex*, int, int);

/*
 * Round x down to the nearest power of two.
 *
 * Returns the largest power of two that is <= x, or 0 when x <= 0.
 * Negative input is rejected up front: right-shifting a negative int is
 * implementation-defined and, with an arithmetic shift, would never reach
 * zero.
 */
static int fix_number_to_exp2n(int x)
{
	int result;

	if (x <= 0) {
		return 0;
	}
	result = 1;
	while (x > 1) {
		x >>= 1;
		result <<= 1;
	}
	return result;
}

/*
 * Reverse the order of the lowest bit_number bits of x.
 *
 * Bits of x above that width are ignored.  A width of zero (or less)
 * yields 0.
 */
static int get_bitrev_number(int bit_number, int x)
{
	int reversed = 0;
	int remaining = bit_number;

	while (remaining-- > 0) {
		/* move the current lowest bit of x onto the low end of the result */
		reversed <<= 1;
		reversed |= (x & 1);
		x >>= 1;
	}
	return reversed;
}

/*
 * Fill this process's slice of the input signal.
 *
 * z          - output, n elements; imaginary parts are set to 0
 * n          - number of elements held by this process
 * rotate_tbl - quarter-length rotation table with (n * punum) / 4 entries
 * punum      - number of participating processes
 * puid       - rank of this process; its slice starts at global index
 *              n * puid
 *
 * The global index is split into a quadrant (upper bits) and a table
 * position (lower bits); the real part is taken from the table entry with
 * the sign/component chosen per quadrant.  With a table produced by
 * make_rotate_table(tbl, n * punum) this yields one period of a cosine
 * across the whole signal.
 */
static void make_data(complex *z, int n, const complex *rotate_tbl, int punum,
					  int puid)
{
	const int base = n * puid;
	int quarter = ((n * punum) >> 2);
	int quarter_shift = 0; /* (1 << quarter_shift) == size / 4 */
	int quarter_mask;      /* (x & quarter_mask) == x % (size / 4) */
	int k;

	for (; quarter > 1; quarter >>= 1) {
		quarter_shift++;
	}
	quarter_mask = (1 << quarter_shift) - 1;

	for (k = 0; k < n; k++) {
		const int pos = k + base;
		const int quadrant = (pos >> quarter_shift);
		const complex *entry = &rotate_tbl[pos & quarter_mask];
		data_t value;

		if (quadrant == 0) {
			value = entry->re;
		} else if (quadrant == 1) {
			value = -entry->im;
		} else if (quadrant == 2) {
			value = -entry->re;
		} else { /* quadrant 3 */
			value = entry->im;
		}
		z[k].re = value;
		z[k].im = 0;
	}
}

/*
 * Copy n complex values from src to dst.
 * Nothing is copied when n <= 0.
 */
static void copy_data(const complex *src, complex *dst, int n)
{
	int remaining;

	for (remaining = n; remaining > 0; remaining--) {
		*dst++ = *src++;
	}
}

/*
 * Compare two complex arrays of n elements.
 *
 * Accumulates the squared distance between corresponding elements over
 * the whole array and returns non-zero when the square root of that sum
 * is below DIFF_LIMIT, zero otherwise.  For n <= 0 the sum is 0, so the
 * arrays are reported as matching.
 */
static int check(const complex *z1, const complex *z2, int n)
{
	data_t diff_sigma = 0;
	int i;

	for (i = 0; i < n; i++) {
		const data_t dx = z1[i].re - z2[i].re;
		const data_t dy = z1[i].im - z2[i].im;
		/* accumulate: every element must contribute, not only the last */
		diff_sigma += dx * dx + dy * dy;
	}
	return sqrt(diff_sigma) < DIFF_LIMIT;
}

/*
 * Build the quarter-length rotation-factor table for a transform of
 * length n.
 *
 * tbl must have room for n / 4 entries (at least one).  Entry i is meant
 * to hold (cos(2*PI*i/n), sin(2*PI*i/n)).  Instead of calling cos/sin per
 * entry, the values for the first eighth of the circle are produced by an
 * incremental recurrence and mirrored into the second eighth, swapping
 * the real and imaginary parts.  The middle entry n / 8 is set directly
 * to sqrt(0.5) in both components.
 *
 * n is expected to be a power of two (main() guarantees this).
 */
static void make_rotate_table(complex *tbl, int n)
{
	const int quarter = n / 4;
	const int eighth = n / 8;
	data_t cos_k, sin_k;  /* current table value */
	data_t d_cos, d_sin;  /* current step applied to cos_k / sin_k */
	data_t scale;
	int k;

	scale = sin(PI / n);
	d_cos = 2 * scale * scale;
	d_sin = sqrt(d_cos * (2 - d_cos));
	scale = 2 * d_cos;

	cos_k = 1;
	sin_k = 0;
	tbl[0].re = cos_k;
	tbl[0].im = sin_k;

	for (k = 1; k < eighth; k++) {
		complex *const low = &tbl[k];
		complex *const mirrored = &tbl[quarter - k];

		cos_k -= d_cos;
		sin_k += d_sin;
		d_cos += scale * cos_k;
		d_sin -= scale * sin_k;

		low->re = cos_k;
		mirrored->im = cos_k;
		low->im = sin_k;
		mirrored->re = sin_k;
	}
	if (eighth != 0) {
		const data_t diagonal = sqrt(0.5);
		tbl[eighth].re = diagonal;
		tbl[eighth].im = diagonal;
	}
}

/*
 * Program entry point.
 *
 * Command line (after MPI has consumed its own arguments):
 *   -i        print clock() values for the init and calculation phases
 *   -t        test mode: run fft() followed by ufft() on copies of the data
 *   -v        verbose: print the sizes and, for size <= MAX_PRINT_SIZE,
 *             the individual elements
 *   <number>  transform size; rounded down to a power of two, minimum 4
 *
 * The number of participating processes is the communicator size rounded
 * down to a power of two; surplus ranks leave immediately.  Returns 0 on
 * success and 1 when there are more processes than size / 2.  On
 * allocation failure the whole MPI job is aborted.
 */
int main(int argc, char **argv)
{
	/* upper bound for the requested size: largest power of two in an int */
	const long size_limit = (1L << 30);
	clock_t init_start_time, init_end_time, calc_start_time, calc_end_time;
	complex *z1;
	complex *z2 = NULL;
	complex *z3 = NULL;
	complex *tmp_buf;
	complex *rotate_tbl;
	int size = DEFAULT_SIZE;
	int check_flag = 0, info_flag = 0, verbose_flag = 0;
	/* MPI_Comm_size / MPI_Comm_rank write through int pointers */
	int punum, puid, n;

	/* initialize MPI environment */
	MPI_Init(&argc, &argv);
	MPI_Comm_size(MPI_COMM_WORLD, &punum);
	MPI_Comm_rank(MPI_COMM_WORLD, &puid);
	punum = fix_number_to_exp2n(punum);
	if (puid >= punum) {
		/* surplus rank: nothing to do, but MPI must still be shut down */
		MPI_Finalize();
		return 0;
	}

	/* read arguments */
	while (*++argv != NULL) {
		const char *arg = *argv;
		if (arg[0] == '-') {
			switch (arg[1]) {
			case 'i':
				info_flag = 1;
				break;
			case 't':
				check_flag = 1;
				break;
			case 'v':
				verbose_flag = 1;
				break;
			default:
				break;
			}
		} else if (isdigit((unsigned char)arg[0])) {
			/* the argument starts with a digit, so the value is never
			   negative; strtol saturates on overflow instead of invoking
			   undefined behavior */
			long requested = strtol(arg, NULL, 10);
			if (requested > size_limit) requested = size_limit;
			size = (int)requested;
		}
	}
	size = fix_number_to_exp2n(size);
	if (size < 4) size = 4;
	if (punum > size / 2) {
		/* every participating rank takes this branch together */
		if (puid == 0) fprintf(stderr, "Too many processors.\n");
		MPI_Finalize();
		return 1;
	}
	n = size / punum;
	if (verbose_flag && puid == 0) {
		printf("size:%d punum:%d\n", size, punum);
	}
	init_start_time = init_end_time = calc_start_time = calc_end_time = 0;

	/* initialize */
	if (info_flag) init_start_time = clock();
	z1 = malloc((size_t)n * sizeof *z1);
	tmp_buf = malloc((size_t)(n / 2) * sizeof *tmp_buf);
	rotate_tbl = malloc((size_t)(size / 4) * sizeof *rotate_tbl);
	if (check_flag) {
		z2 = malloc((size_t)n * sizeof *z2);
		z3 = malloc((size_t)n * sizeof *z3);
	}
	if (z1 == NULL || tmp_buf == NULL || rotate_tbl == NULL ||
		(check_flag && (z2 == NULL || z3 == NULL))) {
		fputs("Out of memory.\n", stderr);
		/* other ranks may be waiting on this one: take the job down */
		MPI_Abort(MPI_COMM_WORLD, 1);
		exit(1);
	}
	make_rotate_table(rotate_tbl, size);
	make_data(z1, n, rotate_tbl, punum, puid);
	if (info_flag) init_end_time = clock();

	/* calculate and check, show results */
	if (!check_flag) {
		if (info_flag) calc_start_time = clock();
		fft(z1, tmp_buf, n, rotate_tbl, punum, puid);
		if (info_flag) calc_end_time = clock();
		if (verbose_flag && size <= MAX_PRINT_SIZE) {
			const int offset = n * puid;
			int i;
			for (i = 0; i < n; i++) {
				printf("%03d: (% 5.3f, % 5.3f)\n",
					   i + offset, z1[i].re, z1[i].im);
			}
		}
	} else {
		if (info_flag) calc_start_time = clock();
		copy_data(z1, z2, n);
		fft(z2, tmp_buf, n, rotate_tbl, punum, puid);
		copy_data(z2, z3, n);
		ufft(z3, tmp_buf, n, rotate_tbl, punum, puid);
		if (info_flag) calc_end_time = clock();
		if (verbose_flag && size <= MAX_PRINT_SIZE) {
			const int offset = n * puid;
			int i;
			/* columns: input, fft output, ufft output, |input - ufft| */
			for (i = 0; i < n; i++) {
				data_t dx, dy, d;
				dx = z1[i].re - z3[i].re;
				dy = z1[i].im - z3[i].im;
				d = sqrt(dx * dx + dy * dy);
				printf("%03d: ", i + offset);
				printf("(% 5.3f, % 5.3f) ", z1[i].re, z1[i].im);
				printf("(% 5.3f, % 5.3f) ", z2[i].re, z2[i].im);
				printf("(% 5.3f, % 5.3f) ", z3[i].re, z3[i].im);
				printf("%9.3e\n", d);
			}
		}
	}
	if (info_flag && puid == 0) {
		printf("init start: %10ld\n"
			   "init end:   %10ld\n"
			   "calc start: %10ld\n"
			   "calc end:   %10ld\n"
			   "init time:  %10ld\n"
			   "calc time:  %10ld\n",
			   (long)init_start_time, (long)init_end_time,
			   (long)calc_start_time, (long)calc_end_time,
			   (long)(init_end_time - init_start_time),
			   (long)(calc_end_time - calc_start_time));
	}

	free(z1);
	free(z2);
	free(z3);
	free(tmp_buf);
	free(rotate_tbl);

	MPI_Finalize();
	return 0;
}

void fft(complex *z, complex *tmp_buf, int n, const complex* rotate_tbl,
		 int punum, int puid)
{
	int offset;           /* 0 <= (i + offset) < size */
	int n_shift;          /* (1 << n_shift) == n */
	int size_shift;       /* (1 << size_shift) == size */
	int punum_shift;      /* (1 << punum_shift) == punum */
	int size_div_4_shift; /* (1 << size_div_4_shift) == size / 4 */
	int size_div_4_mask;  /* (x & size_div_4_mask) == x % (size / 4) */
	int size_div_i_shift; /* (1 << size_div_i_shift) == size / i */
	int i, j;
	n_shift = 0;
	i = n;
	while (i > 1) {
		i >>= 1;
		n_shift++;
	}
	size_shift = 0;
	i = n * punum;
	while (i > 1) {
		i >>= 1;
		size_shift++;
	}
	punum_shift = 0;
	i = punum;
	while (i > 1) {
		i >>= 1;
		punum_shift++;
	}
	offset = (puid << n_shift);
	size_div_4_shift = size_shift - 2;
	size_div_4_mask = (1 << size_div_4_shift) - 1;
	size_div_i_shift = 0;
	{
		const int n_div_2 = (n >> 1);
		const int size_mask = (1 << size_shift) - 1;
		complex* const z_half = &(z[n_div_2]);
		const int comm_blocksize = n_div_2 * sizeof(z[0]);
		for (; size_div_i_shift < punum_shift; size_div_i_shift++) {
			int partner = puid ^ (1 << (punum_shift - size_div_i_shift - 1));
			i = (1 << (size_shift - size_div_i_shift));
			if (puid < partner) {
				MPI_Status dummy;
				MPI_Send(z_half, comm_blocksize, MPI_CHAR, partner,
						 (size_div_i_shift << 1), MPI_COMM_WORLD);
				MPI_Recv(tmp_buf, comm_blocksize, MPI_CHAR, partner,
						 (size_div_i_shift << 1), MPI_COMM_WORLD, &dummy);
				for (j = 0; j < n_div_2; j++) {
					data_t t1x, t1y, t2x, t2y;
					data_t rx, ry;
					{
						const int r_idx =
							((j + offset) << size_div_i_shift) & size_mask;
						const int r_idx_upper = (r_idx >> size_div_4_shift);
						const int r_idx_lower = (r_idx & size_div_4_mask);
						switch (r_idx_upper) {
						case 0:
							rx = rotate_tbl[r_idx_lower].re;
							ry = rotate_tbl[r_idx_lower].im;
							break;
						case 1:
							rx = -rotate_tbl[r_idx_lower].im;
							ry = rotate_tbl[r_idx_lower].re;
							break;
						case 2:
							rx = -rotate_tbl[r_idx_lower].re;
							ry = -rotate_tbl[r_idx_lower].im;
							break;
						default: /* case 3 */
							rx = rotate_tbl[r_idx_lower].im;
							ry = -rotate_tbl[r_idx_lower].re;
							break;
						}
					}
					t1x = z[j].re, t1y = z[j].im;
					t2x = tmp_buf[j].re, t2y = tmp_buf[j].im;
#ifdef DEBUG
					printf("i:%02d - r[%02d]:(%+4.2f,%+4.2f)"
						   " z[%02d]:(%+4.2f,%+4.2f), z[%02d]:(%+4.2f,%+4.2f)"
						   " - %02d/%02d: %02d\n",
						   i,
						   ((j + offset) << size_div_i_shift) & size_mask,
						   rx, ry,
						   j + offset, t1x, t1y,
						   j + offset + i / 2, t2x, t2y,
						   puid, punum, partner);
#endif
					z[j].re = t1x + t2x;
					z[j].im = t1y + t2y;
					t2x = t1x - t2x;
					t2y = t1y - t2y;
					tmp_buf[j].re = t2y * ry + t2x * rx;
					tmp_buf[j].im = t2y * rx - t2x * ry;
				}
				MPI_Send(tmp_buf, comm_blocksize, MPI_CHAR, partner,
						 (size_div_i_shift << 1) + 1, MPI_COMM_WORLD);
				MPI_Recv(z_half, comm_blocksize, MPI_CHAR, partner,
						 (size_div_i_shift << 1) + 1, MPI_COMM_WORLD, &dummy);
			} else {
				const int partner_offset = (partner << n_shift);
				MPI_Status dummy;
				MPI_Send(z, comm_blocksize, MPI_CHAR, partner,
						 (size_div_i_shift << 1), MPI_COMM_WORLD);
				MPI_Recv(tmp_buf, comm_blocksize, MPI_CHAR, partner,
						 (size_div_i_shift << 1), MPI_COMM_WORLD, &dummy);
				for (j = 0; j < n_div_2; j++) {
					data_t t1x, t1y, t2x, t2y;
					data_t rx, ry;
					int k = j + n_div_2;
					{
						const int r_idx =
							((k + partner_offset) << size_div_i_shift)
							& size_mask;
						const int r_idx_upper = (r_idx >> size_div_4_shift);
						const int r_idx_lower = (r_idx & size_div_4_mask);
						switch (r_idx_upper) {
						case 0:
							rx = rotate_tbl[r_idx_lower].re;
							ry = rotate_tbl[r_idx_lower].im;
							break;
						case 1:
							rx = -rotate_tbl[r_idx_lower].im;
							ry = rotate_tbl[r_idx_lower].re;
							break;
						case 2:
							rx = -rotate_tbl[r_idx_lower].re;
							ry = -rotate_tbl[r_idx_lower].im;
							break;
						default: /* case 3 */
							rx = rotate_tbl[r_idx_lower].im;
							ry = -rotate_tbl[r_idx_lower].re;
							break;
						}
					}
					t1x = tmp_buf[j].re, t1y = tmp_buf[j].im;
					t2x = z[k].re, t2y = z[k].im;
#ifdef DEBUG
					printf("i:%02d - r[%02d]:(%+4.2f,%+4.2f)"
						   " z[%02d]:(%+4.2f,%+4.2f), z[%02d]:(%+4.2f,%+4.2f)"
						   " - %02d/%02d: %02d\n",
						   i,
						   ((k + partner_offset) << size_div_i_shift)
						   	& size_mask,
						   rx, ry,
						   k + partner_offset, t1x, t1y,
						   k + partner_offset + i / 2, t2x, t2y,
						   puid, punum, partner);
#endif
					tmp_buf[j].re = t1x + t2x;
					tmp_buf[j].im = t1y + t2y;
					t2x = t1x - t2x;
					t2y = t1y - t2y;
					z[k].re = t2y * ry + t2x * rx;
					z[k].im = t2y * rx - t2x * ry;
				}
				MPI_Send(tmp_buf, comm_blocksize, MPI_CHAR, partner,
						 (size_div_i_shift << 1) + 1, MPI_COMM_WORLD);
				MPI_Recv(z, comm_blocksize, MPI_CHAR, partner,
						 (size_div_i_shift << 1) + 1, MPI_COMM_WORLD, &dummy);
			}
		}
	}
	{
		int k1, k2;
		for (; size_div_i_shift < size_shift; size_div_i_shift++) {
			i = (1 << (size_shift - size_div_i_shift));
			for (j = 0; j < i / 2; j++) {
				data_t rx, ry;
				{
					const int r_idx = (j << size_div_i_shift);
					const int r_idx_upper = (r_idx >> size_div_4_shift);
					const int r_idx_lower = (r_idx & size_div_4_mask);
					switch (r_idx_upper) {
					case 0:
						rx = rotate_tbl[r_idx_lower].re;
						ry = rotate_tbl[r_idx_lower].im;
						break;
					case 1:
						rx = -rotate_tbl[r_idx_lower].im;
						ry = rotate_tbl[r_idx_lower].re;
						break;
					case 2:
						rx = -rotate_tbl[r_idx_lower].re;
						ry = -rotate_tbl[r_idx_lower].im;
						break;
					default: /* case 3 */
						rx = rotate_tbl[r_idx_lower].im;
						ry = -rotate_tbl[r_idx_lower].re;
						break;
					}
				}
				for (k1 = j; k1 < n; k1 += i) {
					data_t t1x, t1y, t2x, t2y;
					k2 = k1 + i / 2;
					t1x = z[k1].re, t1y = z[k1].im;
					t2x = z[k2].re, t2y = z[k2].im;
#ifdef DEBUG
					printf("i:%02d - r[%02d]:(%+4.2f,%+4.2f)"
						   " z[%02d]:(%+4.2f,%+4.2f), z[%02d]:(%+4.2f,%+4.2f)"
						   " - %02d/%02d: lo\n",
						   i, (j << size_div_i_shift), rx, ry,
						   k1 + offset, t1x, t1y, k2 + offset, t2x, t2y,
						   puid, punum);
#endif
					z[k1].re = t1x + t2x;
					z[k1].im = t1y + t2y;
					t2x = t1x - t2x;
					t2y = t1y - t2y;
					z[k2].re = t2y * ry + t2x * rx;
					z[k2].im = t2y * rx - t2x * ry;
				}
			}
		}
	}
	{
		const int size = (1 << size_shift);
		for (i = 0; i < n; i++) z[i].re /= size, z[i].im /= size;
	}
}

/*
 * ufft() -- placeholder taking the same parameter list as fft().
 *
 * The body is empty: no argument is read and z is left unchanged.
 * NOTE(review): main() applies this to the output of fft() in its -t
 * path, so it is presumably meant to be the inverse transform -- it is
 * not implemented here; confirm the intent before relying on that path.
 */
void ufft(complex *z, complex *tmp_buf, int n, const complex *rotate_tbl,
		  int punum, int puid)
{
	/* intentionally empty */
}
