// Built by Peter A Noble Dec 16 2023 Email: panoble@gmail.com
// Copyright 2023

#include <fstream>
#include <string>
#include <iostream>
#include <math.h>
#include <cstdlib>
#include <float.h>
#include <complex>
#include <iomanip>
#include <cmath>
#include <vector>

// g++ cnn.cpp -o cnn
// ./cnn 64 28 84 /Users/peternoble/Desktop/mnist_png3/testing/2/77.txt
// Purpose: show that the equations of a CNN model trained in PyTorch can be extracted and evaluated in plain C++.

using namespace std;

// ---------------------------------------------------------------------------
// Forward pass of a small CNN whose parameters were exported as text files.
//
// Pipeline: 2x2 convolution (stride 1, one-pixel zero border) + ReLU
//   -> 2x2 max pooling (stride 2) -> flatten (filter, row, column)
//   -> dense layer with 10 outputs -> softmax.
// NOTE(review): this presumably mirrors Conv2d(1, filter_size, 2, padding=1),
// MaxPool2d(2) and Linear(..., 10) in the PyTorch training script -- confirm.
//
// Parameter files (whitespace-separated numbers, read from the working dir):
//   weights_0.txt  filter_size * 2 * 2 kernel weights, row-major per filter
//   biases_0.txt   filter_size convolution biases
//   weights_1.txt  10 rows of (filter_size * pooled_h * pooled_w) dense weights
//   biases_1.txt   10 dense biases
// Extra trailing numbers in any file are ignored.
// ---------------------------------------------------------------------------
namespace {

// Number of output classes (digits 0-9).
const int kOutputs = 10;
// Side length of the square convolution kernel.
const int kKernel = 2;
// Side length and stride of the square max-pooling window.
const int kPool = 2;
// Sanity bound on every command-line dimension: keeps per-feature-map index
// arithmetic well inside int range and rejects obviously mistyped arguments.
const long kMaxDimension = 4096;

// Parses `text` as a decimal integer in [1, kMaxDimension].
// Returns false, leaving `value` untouched, on empty input, trailing
// characters, or an out-of-range number.
bool parse_dimension(const char* text, int& value)
{
    char* end = 0;
    const long parsed = std::strtol(text, &end, 10);
    if (end == text || *end != '\0' || parsed < 1 || parsed > kMaxDimension) {
        return false;
    }
    value = static_cast<int>(parsed);
    return true;
}

// Reads every whitespace-separated number in `path` into `values`.
// On failure prints a diagnostic to stderr and returns false:
//   - `open_error` verbatim if the file cannot be opened,
//   - a parse error if a token is not a number,
//   - a count error if fewer than `needed` numbers were read.
bool load_numbers(const char* path, const char* open_error,
                  std::size_t needed, std::vector<double>& values)
{
    std::ifstream in(path);
    if (!in.is_open()) {
        std::cerr << open_error << std::endl;
        return false;
    }

    double value = 0.0;
    while (in >> value) {
        values.push_back(value);
    }
    // Extraction only stops cleanly at end of file; stopping anywhere else
    // means an unparsable token, which must not be silently treated as data.
    if (!in.eof()) {
        std::cerr << "Error reading " << path << ": unparsable value after "
                  << values.size() << " numbers" << std::endl;
        return false;
    }
    if (values.size() < needed) {
        std::cerr << "Error: " << path << " holds " << values.size()
                  << " numbers, expected at least " << needed << std::endl;
        return false;
    }
    return true;
}

// Copies a height x width row-major image into the interior of a
// (height + 2) x (width + 2) zero-filled row-major buffer, so that the
// 2x2 kernel also covers the edge pixels.
std::vector<double> pad_image(const std::vector<double>& pixels, int height, int width)
{
    const int padded_width = width + 2;
    std::vector<double> padded(static_cast<std::size_t>(height + 2) * padded_width, 0.0);
    for (int row = 0; row < height; ++row) {
        for (int col = 0; col < width; ++col) {
            padded[(row + 1) * padded_width + (col + 1)] = pixels[row * width + col];
        }
    }
    return padded;
}

// Slides each 2x2 filter over the padded image with stride 1, adds the
// filter's bias and clamps negative sums to zero (ReLU).
// `padded` is (height + 2) x (width + 2) row-major; `kernels` holds
// `filters` row-major 2x2 kernels back to back.
// Returns filters x (height + 1) x (width + 1) values, row-major.
std::vector<double> convolve_relu(const std::vector<double>& padded, int height, int width,
                                  int filters, const std::vector<double>& kernels,
                                  const std::vector<double>& biases)
{
    const int padded_width = width + 2;
    const int conv_height = height + 1;
    const int conv_width = width + 1;
    std::vector<double> conv(static_cast<std::size_t>(filters) * conv_height * conv_width);

    std::size_t out = 0;
    for (int f = 0; f < filters; ++f) {
        const double* kernel = &kernels[static_cast<std::size_t>(f) * kKernel * kKernel];
        for (int row = 0; row < conv_height; ++row) {
            for (int col = 0; col < conv_width; ++col) {
                double sum = 0.0;
                // Kernel row outer, kernel column inner: matches the order
                // in which the weights are stored.
                for (int kr = 0; kr < kKernel; ++kr) {
                    for (int kc = 0; kc < kKernel; ++kc) {
                        sum += padded[(row + kr) * padded_width + (col + kc)]
                               * kernel[kr * kKernel + kc];
                    }
                }
                sum += biases[f];
                conv[out++] = (sum < 0.0) ? 0.0 : sum;
            }
        }
    }
    return conv;
}

// 2x2 max pooling with stride 2 over each feature map; a trailing odd
// row/column is dropped. Returns the pooled values flattened in
// (filter, row, column) order -- the order the dense layer consumes.
std::vector<double> max_pool(const std::vector<double>& conv, int filters,
                             int conv_height, int conv_width)
{
    const int pooled_height = conv_height / kPool;
    const int pooled_width = conv_width / kPool;
    const std::size_t plane = static_cast<std::size_t>(conv_height) * conv_width;

    std::vector<double> pooled;
    pooled.reserve(static_cast<std::size_t>(filters) * pooled_height * pooled_width);
    for (int f = 0; f < filters; ++f) {
        const double* feature = &conv[static_cast<std::size_t>(f) * plane];
        for (int pr = 0; pr < pooled_height; ++pr) {
            for (int pc = 0; pc < pooled_width; ++pc) {
                const double* window = feature + (pr * kPool) * conv_width + pc * kPool;
                double best = window[0];
                for (int r = 0; r < kPool; ++r) {
                    for (int c = 0; c < kPool; ++c) {
                        if (window[r * conv_width + c] > best) {
                            best = window[r * conv_width + c];
                        }
                    }
                }
                pooled.push_back(best);
            }
        }
    }
    return pooled;
}

// Fully connected layer:
//   scores[r] = sum_i features[i] * weights[r * features.size() + i] + biases[r]
// Requires weights.size() >= outputs * features.size(), biases.size() >= outputs
// and a non-empty `features`.
std::vector<double> dense(const std::vector<double>& features, const std::vector<double>& weights,
                          const std::vector<double>& biases, int outputs)
{
    const std::size_t n = features.size();
    std::vector<double> scores(outputs, 0.0);
    for (int r = 0; r < outputs; ++r) {
        const double* row = &weights[static_cast<std::size_t>(r) * n];
        double sum = 0.0;
        for (std::size_t i = 0; i < n; ++i) {
            sum += features[i] * row[i];
        }
        scores[r] = sum + biases[r];
    }
    return scores;
}

// Numerically stable softmax in place: subtracts the largest score before
// exponentiating so exp() cannot overflow. If the largest score is not
// finite (inf/NaN) the shift is meaningless, so the raw scores are left as is.
void softmax_in_place(std::vector<double>& scores)
{
    if (scores.empty()) {
        return;
    }
    double largest = scores[0];
    for (std::size_t i = 1; i < scores.size(); ++i) {
        if (scores[i] > largest) {
            largest = scores[i];
        }
    }
    // Finite-ness test that is false for +/-inf and NaN (needs only <float.h>).
    if (!(largest >= -DBL_MAX && largest <= DBL_MAX)) {
        return;
    }

    double total = 0.0;
    for (std::size_t i = 0; i < scores.size(); ++i) {
        scores[i] = std::exp(scores[i] - largest);
        total += scores[i];
    }
    if (total != 0.0) {
        for (std::size_t i = 0; i < scores.size(); ++i) {
            scores[i] /= total;
        }
    }
}

// Writes the per-digit probability table in the program's output format.
void print_probabilities(std::ostream& os, const std::vector<double>& probabilities)
{
    os << "Digit\t\t" << "Probability\n";
    for (std::size_t digit = 0; digit < probabilities.size(); ++digit) {
        os << "Digit[" << digit << "]=\t" << probabilities[digit] << "\n";
    }
}

}  // namespace

// Usage: cnn <filter_size> <height> <width> <image.txt> [output.txt]
//   filter_size   number of convolution filters
//   height/width  image size; image.txt holds height*width pixel values, row by row
//   output.txt    optional; receives a copy of the probability table
// Prints the softmax probability of each digit 0-9 to stdout.
// Returns 0 on success, 1 on bad arguments or missing/malformed/short input files.
int main (int argc, char * const argv[]) {
    if (argc < 5) {
        std::cerr << "Usage: cnn <filter_size> <height> <width> <image.txt> [output.txt]"
                  << std::endl;
        return 1;
    }

    int filter_size = 0;
    int height = 0;
    int width = 0;
    if (!parse_dimension(argv[1], filter_size) ||
        !parse_dimension(argv[2], height) ||
        !parse_dimension(argv[3], width)) {
        std::cerr << "Error: filter_size, height and width must be integers in 1.."
                  << kMaxDimension << std::endl;
        return 1;
    }

    std::ofstream out_file;
    if (argc > 5) {
        out_file.open(argv[5]);
        if (!out_file.is_open()) {
            std::cerr << "Error opening " << argv[5] << std::endl;
            return 1;
        }
    }

    // Shapes derived from the arguments; every buffer is sized from these.
    const int conv_height = height + 1;
    const int conv_width = width + 1;
    const std::size_t pooled_size = static_cast<std::size_t>(filter_size)
                                    * (conv_height / kPool) * (conv_width / kPool);

    std::vector<double> kernels;
    std::vector<double> conv_biases;
    std::vector<double> dense_weights;
    std::vector<double> dense_biases;
    std::vector<double> pixels;
    if (!load_numbers("weights_0.txt", "Error opening weights_0.txt",
                      static_cast<std::size_t>(filter_size) * kKernel * kKernel, kernels) ||
        !load_numbers("biases_0.txt", "Error opening biases_0.txt",
                      static_cast<std::size_t>(filter_size), conv_biases) ||
        !load_numbers("weights_1.txt", "Error opening weights_1.txt",
                      static_cast<std::size_t>(kOutputs) * pooled_size, dense_weights) ||
        !load_numbers("biases_1.txt", "Error opening biases_1.txt",
                      static_cast<std::size_t>(kOutputs), dense_biases) ||
        !load_numbers(argv[4], "Error array.txt",
                      static_cast<std::size_t>(height) * width, pixels)) {
        return 1;
    }

    const std::vector<double> padded = pad_image(pixels, height, width);
    const std::vector<double> conv =
        convolve_relu(padded, height, width, filter_size, kernels, conv_biases);
    const std::vector<double> features = max_pool(conv, filter_size, conv_height, conv_width);

    std::vector<double> probabilities = dense(features, dense_weights, dense_biases, kOutputs);
    softmax_in_place(probabilities);

    print_probabilities(std::cout, probabilities);
    if (out_file.is_open()) {
        print_probabilities(out_file, probabilities);
    }
    return 0;
}


