#include "interval.h"
#include "matmuls.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

/* Smallest of the first n entries of A.
 * A[0] is always read, so A must hold at least one element. */
double min(double A[], int n) {
	double lowest = A[0];
	int i = 1;

	while (i < n) {
		if (A[i] < lowest)
			lowest = A[i];
		i++;
	}
	return lowest;
}

/* Largest of the first n entries of A.
 * A[0] is always read, so A must hold at least one element. */
double max(double A[], int n) {
	double highest = A[0];
	int i;

	for (i = 1; i < n; ++i) {
		if (A[i] > highest)
			highest = A[i];
	}
	return highest;
}

/* Add up the first n entries of A, front to back.
 * For n <= 0 nothing is read and 0 is returned. */
double sum(double A[], int n) {
	double total = 0.0;
	int i = 0;

	while (i < n)
		total += A[i++];
	return total;
}

/* Arithmetic mean of the first n entries of A, i.e. sum(A, n) / n.
 * With n == 0 this evaluates 0.0 / 0, which is NaN on IEEE-754 platforms. */
double mean(double A[], int n) {
	const double total = sum(A, n);

	return total / n;
}

/* Calculate the (population) standard deviation of an array of n doubles,
 * ie.: sqrt(1/n * sum((A - mean(A))^2)).
 * For n == 0 the result is NaN (0.0 / 0 on IEEE-754 platforms). */
double sd(double A[], int n) {
	int a;
	double d, s = 0, m = mean(A, n);

	/* Square each deviation before summing: the raw deviations from the
	 * mean add up to (nearly) zero, so summing them first is useless. */
	for (a = 0; a < n; a++) {
		d = A[a] - m;
		s += d * d;
	}

	return sqrt(s / n);
}

/* Time one call of the matrix multiplication f on the n x n matrices A and
 * B (f writes its result into C), and ADD the wall clock and user cpu time
 * reported by timeInterval() to *wall and *user. The caller initialises
 * *wall and *user; A and B must already be ready for multiplication.
 * The third value reported by timeInterval() (presumably system time) is
 * discarded. */
void timeAMatmulsCall(void (*f)(double *, double *, double *, int, int, int),
	int n, double A[], double B[], double C[], double *wall, double *user) {
	double elapsedWall, elapsedUser, elapsedSys;
	interval timer = newInterval();

	f(A, B, C, n, n, n);

	timeInterval(timer, &elapsedWall, &elapsedUser, &elapsedSys);
	*wall += elapsedWall;
	*user += elapsedUser;
}

/* Run a series of timing tests on a matrix multiplication function and print
 * one row of statistics.
 *
 * For every size in `sizes`, two square matrices are filled with drand48()
 * values and f is called repeatedly until at least 1 second of user cpu time
 * has accumulated. The accumulated wall clock and user cpu times are divided
 * by the number of floating point operations performed, giving a time per
 * flop for each size. The printed row holds min, mean, max and sd of those
 * values, first for wall clock time, then for user cpu time.
 * If allocation fails, an error goes to stderr and no row is printed. */
void timeAMatmulsFunction(void (*f)(double *, double *, double *, int, int,
	int)) {
	/* Matrix sizes to benchmark. Shrink this list for a quicker run; the
	 * result arrays and matrix buffers are sized from it. */
	int sizes[] = { 20, 30, 40, 60, 80, 120, 160, 240, 320,
		480, 640, 960 };
	double resultsw[sizeof(sizes) / sizeof(sizes[0])];
	double resultsu[sizeof(sizes) / sizeof(sizes[0])];
	int a, b, c, n, size, maxSize = 0;
	size_t elems;
	double w, u, ops, flops, *A, *B, *C;

	n = sizeof(sizes) / sizeof(sizes[0]);

	/* Allocate once for the largest size and reuse the buffers. */
	for (a = 0; a < n; a++)
		maxSize = sizes[a] > maxSize ? sizes[a] : maxSize;
	elems = (size_t)maxSize * (size_t)maxSize;
	A = malloc(sizeof *A * elems);
	B = malloc(sizeof *B * elems);
	C = malloc(sizeof *C * elems);
	if (A == NULL || B == NULL || C == NULL) {
		fprintf(stderr, "could not allocate memory\n");
		/* Release whichever allocations did succeed; free(NULL) is a
		 * no-op. */
		free(A);
		free(B);
		free(C);
		return;
	}

	for (a = 0; a < n; a++) {
		size = sizes[a];

		/* Generate two square matrices with random floating point
		 * values. */
		for (b = 0; b < size; b++) {
			for (c = 0; c < size; c++) {
				A[b*size+c] = drand48();
				B[b*size+c] = drand48();
			}
		}
		w = u = 0;

		/* Run test until 1 second of user cpu time has accumulated;
		 * b counts the calls made. */
		for (b = 0; u < 1; b++)
			timeAMatmulsCall(f, size, A, B, C, &w, &u);

		/* Floating point operations per product: size^3
		 * multiplications plus size^2 * (size - 1) additions.
		 * Evaluated in double: in int, b times this count overflows
		 * (undefined behaviour) for the larger sizes, e.g. from
		 * b = 2 at size 960. */
		ops = (double)size * size * size
			+ (double)size * (size - 1) * size;
		flops = b * ops;
		resultsw[a] = w / flops;
		resultsu[a] = u / flops;
	}
	printf("%.3e %.3e %.3e %.3e ", min(resultsw, n), mean(resultsw, n),
		max(resultsw, n), sd(resultsw, n));
	printf("%.3e %.3e %.3e %.3e\n", min(resultsu, n), mean(resultsu, n),
		max(resultsu, n), sd(resultsu, n));
	free(A);
	free(B);
	free(C);
}

/* Demonstrate correctness of a matrix multiplication function by comparing
 * its output to a known good version.
 *
 * Two random n x n matrices (n = 5) are multiplied by the reference
 * MatMul_21 and by MatMul_22, the function under test. The element-wise
 * differences (reference minus tested) are printed as an n x n grid; they
 * should be (close to) zero if the two functions agree.
 * If allocation fails, an error goes to stderr and nothing is compared. */
void verify(void) {
	int b, c, n = 5;
	size_t elems = (size_t)n * (size_t)n;
	double *A, *B, *C, *D;

	/* Only n x n elements are used, so allocate exactly that much. */
	A = malloc(sizeof *A * elems);
	B = malloc(sizeof *B * elems);
	C = malloc(sizeof *C * elems);
	D = malloc(sizeof *D * elems);
	if (A == NULL || B == NULL || C == NULL || D == NULL) {
		fprintf(stderr, "could not allocate memory\n");
		/* Release whichever allocations did succeed; free(NULL) is a
		 * no-op. */
		free(A);
		free(B);
		free(C);
		free(D);
		return;
	}

	/* Generate two square matrices with random floating point
	 * values. */
	for (b = 0; b < n; b++) {
		for (c = 0; c < n; c++) {
			A[b*n+c] = drand48();
			B[b*n+c] = drand48();
		}
	}

	/* Do the matrix multiplication: first with the reference function,
	 * then with the function to be tested. */
	MatMul_21(A, B, C, n, n, n);
	MatMul_22(A, B, D, n, n, n);
	for (b = 0; b < n; b++) {
		for (c = 0; c < n; c++) {
			printf("%.3e ", C[b*n+c] - D[b*n+c]);
		}
		printf("\n");
	}

	free(A);
	free(B);
	free(C);
	free(D);
}

/* Produce a table of statistics for the given matrix multiplication
 * functions: a header, then one row per function. The left half of each row
 * is for wall clock time, the right half for user cpu time; each half lists
 * min, mean, max and sdev of the per-size measurements taken by
 * timeAMatmulsFunction(). */
int main() {
	/* Functions to benchmark, in the order their rows are printed. */
	void (*const candidates[])(double *, double *, double *, int, int,
		int) = { MatMul_1, MatMul_3, MatMul_19, MatMul_21, MatMul_22 };
	size_t k;

	/* To check MatMul_22 against MatMul_21 before timing, call verify()
	 * here. */

	printf("matrix multiplication measurements\n");

	printf("\nwall: min   mean      max       sdev ");
	printf("   cpu: min    mean      max       sdev\n");

	for (k = 0; k < sizeof candidates / sizeof candidates[0]; k++)
		timeAMatmulsFunction(candidates[k]);

	return 0;
}
