#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <string.h>
#include <curses.h>
#include "vPSRAM.h"
#include "vDISP.h"
#include "vUART.h"
#include "vSPI.h"
#include "vSD.h"
#include "soc.h"
#include "cpu.h"

//ROM image handed to socInit() and the 4096-byte window of it that cpuExtRomRead() currently serves
static const uint8_t *mRomAll, *mRomCurPage;
//sizes handed to socInit(): ROM image size and number of elements in the RAM array
static uint32_t mRomSz, mRamSz;
//RAM storage (one 4-bit value per element, accessors mask with 15) and the separate status-nibble storage
static uint8_t *mRam, *mRamSta;
//emulated peripherals created in socInit(); the port accessors test the PSRAM, display and UART handles for NULL before use
static struct VPSRAM *mPsram0, *mPsram1;
static struct VDISP *mDISP;
static struct VUART *mUART;
static struct VSD *mVSD;

//number of ROM bytes reachable through cpuExtRomRead() at once; larger ROMs are paged in units of this size
#define I4004_ROM_DIRECT_SIZE		4096

//4004 things
//offsets into mRam (in nibbles) of 32-bit values; cpuPrvReadU32() reads 8 consecutive nibbles, least significant first
//all of them sit at 256 + something (presumably "DCL 1" means the second 256-nibble bank - TODO confirm)
#define DCL1_TMP32A		(1 * 256 + 0xB0)
#define DCL1_TMP32B		(1 * 256 + 0xA8)
#define DCL1_TMP32C		(1 * 256 + 0xA0)
#define DCL1_PC			(1 * 256 + 0x00)
#define DCL1_EPC		(1 * 256 + 0x60)
#define DCL1_NPC		(1 * 256 + 0x08)
#define DCL1_HI			(1 * 256 + 0x78)
#define DCL1_LO			(1 * 256 + 0x70)
#define DCL1_CAUSE		(1 * 256 + 0x10)
#define DCL1_BADVA		(1 * 256 + 0x80)
#define DCL1_ENTRYHI	(1 * 256 + 0x90)
#define DCL1_ENTRYLO	(1 * 256 + 0x98)
#define DCL1_STATUS		(1 * 256 + 0x18)
#define DCL1_INSTR		(1 * 256 + 0x68)

//TLB hash table, one nibble per bucket (the TLB report/check code walks 31 buckets):
//PRESENT holds 0 (empty) or 1 (in use), HEADS holds the index of the first TLB entry of the bucket
//NOTE: these two are spelled "DLC1", unlike the "DCL1" names above
#define DLC1_HASH_PRESENT		(1 * 256 + 0x20)
#define DLC1_HASH_HEADS			(1 * 256 + 0x40)

//number of TLB entries; entry i keeps its hi/lo words at mRam[2 * 256 + 16 * i] and its link nibbles at mRamSta[2 * 64 + 4 * i]
#define NUM_TLB_ENTRIES			16

//mips things
//register numbers; in the standard o32 calling convention none of these are preserved across calls
#define MIPS_REG_AT	1	//assembler use (caller saved)
#define MIPS_REG_V0	2	//return val 0 (caller saved)
#define MIPS_REG_V1	3	//return val 1 (caller saved)
#define MIPS_REG_A0	4	//arg 0 (caller saved)
#define MIPS_REG_A1	5	//arg 1 (caller saved)
#define MIPS_REG_A2	6	//arg 2 (caller saved)
#define MIPS_REG_A3	7	//arg 3 (caller saved)

//hypercall numbers (not referenced by the code visible in this file; presumably used by the guest-facing interface - TODO confirm)
#define H_CONSOLE_WRITE						1
#define H_STOR_GET_SZ						2
#define H_STOR_READ							3
#define H_STOR_WRITE						4


//number of emulated MIPS instructions reported so far (incremented once per in-order call to cpuPrvReportMipsState)
uint32_t numMipsCycles = 0;
//when true, cpuPrvReportMipsState dumps the full MIPS state to stderr on every call
bool report = false;
//sum of the per-instruction 4004 cycle counts accumulated by cpuPrvReportMipsState
uint64_t totalCy = 0;

//Sample the 4-bit ROM input port. Each attached SPI device contributes one bit when it drives its output high:
//  bit 3 (8) - SD card, bit 2 (4) - UART, bit 1 (2) - display, bit 0 (1) - the two PSRAMs (shared line)
//A device that is absent (NULL handle) contributes nothing.
//If both PSRAMs drive the shared line to different levels the emulator aborts, since that is a bus conflict.
uint8_t cpuExtRomPortRead(void)
{
	enum PinOutState r0, r1;
	uint8_t ret = 0;

	//guarded like every other device so that a missing SD does not turn into a NULL dereference
	if (mVSD && vspiPinRead(vsdGetVSPI(mVSD)) == PinHigh)
		ret += 8;

	if (mDISP && vspiPinRead(vdispGetVSPI(mDISP)) == PinHigh)
		ret += 2;

	if (mUART && vspiPinRead(vuartGetVSPI(mUART)) == PinHigh)
		ret += 4;

	r0 = mPsram0 ? vspiPinRead(vpsramGetVSPI(mPsram0)) : PinHiZ;
	r1 = mPsram1 ? vspiPinRead(vpsramGetVSPI(mPsram1)) : PinHiZ;

	if (r0 != PinHiZ && r1 != PinHiZ && r0 != r1) {	//allow same output for initial boot up case

		fprintf(stderr, "PSRAMs fighting over the bus\n");
		exit(-1);
	}
	if (r0 == PinHigh || r1 == PinHigh)
		ret += 1;

	return ret;
}

//Fetch one byte from the currently selected ROM page.
//Addresses at or beyond I4004_ROM_DIRECT_SIZE are a fatal emulator error.
uint8_t cpuExtRomRead(uint_fast16_t addr)
{
	if (addr >= I4004_ROM_DIRECT_SIZE) {

		fprintf(stderr, "ROM READ FAIL OF 0x%03x\n", (unsigned)addr);
		exit(-1);
	}

	return mRomCurPage[addr];
}

//Report which ROM page is mapped: 0 when the window sits at the start of the ROM image, 1 otherwise.
uint8_t socGetRomPage(void)
{
	if (mRomCurPage == mRomAll)
		return 0;

	return 1;
}

//Update nibble `idx` (0 = least significant, 7 = most significant) of the PC value shown on screen.
//The display is only redrawn when nibble 7 arrives, so nibbles are expected in order.
//The 32 bits are drawn MSB first on screen row 5 starting at column 3, 'X' for a set bit and '-' for a clear one.
static void socPrvUpdteShownPc(uint_fast8_t idx, uint_fast8_t val)		//we assume in-order update...
{
	static uint32_t mShownPc;
	uint32_t shift, mask, orr;

	//a 32-bit value only has 8 nibbles; shifting by 32 or more would be undefined
	if (idx > 7)
		return;

	shift = 4 * (uint32_t)idx;
	mask = (uint32_t)0x0f << shift;				//unsigned on purpose: (int)0x0f << 28 overflows a signed int
	orr = ((uint32_t)val & 0x0f) << shift;		//only a nibble is meaningful, do not let stray bits leak into neighbours

	mShownPc = (mShownPc &~ mask) | orr;

	if (idx == 7) {
		uint32_t i, v = mShownPc;

		for (i = 0; i < 32; i++, v <<= 1)
			mvaddch(5, i + 3, (v >> 31) ? 'X' : '-');

		refresh();
	}
}

//Handle a write of nibble `val` to the output port of RAM chip `chipIdx`.
//  chips 0..3 and 8..11 - nibbles 0..3 and 4..7 of the PC shown on screen
//  chip 4               - SD card SPI pins, plus ROM page select (bit 3) when the ROM is larger than one page
//  chip 5               - display and UART SPI pins (first two lines shared, third taken from bit 2 resp. bit 3)
//  chip 6               - PSRAM0 and PSRAM1 SPI pins (first two lines shared, third taken from bit 2 resp. bit 3)
//Every SPI line is passed on inverted. A write to any other chip is a fatal emulator error.
void cpuExtRamPortWrite(uint_fast8_t chipIdx, uint_fast8_t val)			//only nibble written
{
	//inverted views of the four port bits, as handed to the virtual SPI layer
	const int bit0Clr = !(val & 1);
	const int bit1Clr = !(val & 2);
	const int bit2Clr = !(val & 4);
	const int bit3Clr = !(val & 8);

	switch (chipIdx) {
		case 0:
		case 1:
		case 2:
		case 3:
			socPrvUpdteShownPc(chipIdx, val);
			break;

		case 8:
		case 9:
		case 10:
		case 11:
			socPrvUpdteShownPc(chipIdx - 4, val);
			break;

		case 4:		//SD (chip 0 in dcl 1) and rom bank select

			//bit 3 clear selects the second ROM page, bit 3 set selects the first
			if (mRomSz > I4004_ROM_DIRECT_SIZE)
				mRomCurPage = bit3Clr ? mRomAll + I4004_ROM_DIRECT_SIZE : mRomAll;

			vspiPinsWritten(vsdGetVSPI(mVSD), bit0Clr, bit1Clr, bit2Clr);
			break;

		case 5:		//UART and VFD (chip 1 in dcl 1)

			if (mDISP)
				vspiPinsWritten(vdispGetVSPI(mDISP), bit0Clr, bit1Clr, bit2Clr);	//selected by pin 2

			if (mUART)
				vspiPinsWritten(vuartGetVSPI(mUART), bit0Clr, bit1Clr, bit3Clr);	//selected by pin 3

			break;

		case 6:		//spi ram (chip 2 in dcl 1)

			if (mPsram0)
				vspiPinsWritten(vpsramGetVSPI(mPsram0), bit0Clr, bit1Clr, bit2Clr);
			if (mPsram1)
				vspiPinsWritten(vpsramGetVSPI(mPsram1), bit0Clr, bit1Clr, bit3Clr);
			break;

		default:
			fprintf(stderr, "unexpected write to chip %u port\n", chipIdx);
			exit(-1);
	}
}

//Read one status nibble. There are mRamSz / 4 of them; anything beyond that is a fatal emulator error.
uint8_t cpuExtStatusByteRead(uint_fast16_t addr)
{
	if (addr >= mRamSz / 4) {

		fprintf(stderr, "RAM STA READ FAIL OF 0x%03x\n", (unsigned)addr);
		exit(-1);
	}

	return mRamSta[addr] & 15;
}

//Store the low 4 bits of `val` as a status nibble. There are mRamSz / 4 of them; anything beyond that is a fatal emulator error.
void cpuExtStatusByteWrite(uint_fast16_t addr, uint_fast8_t val)
{
	if (addr >= mRamSz / 4) {

		fprintf(stderr, "RAM STA WRITE FAIL OF 0x%03x\n", (unsigned)addr);
		exit(-1);
	}

	mRamSta[addr] = val & 15;
}

//Read one RAM nibble.
//Everything at or past the end of the TLB entry area (0x200 + 0x10 * NUM_TLB_ENTRIES) reads as zero;
//below that, an address outside the RAM array is a fatal emulator error.
uint8_t cpuExtRamRead(uint_fast16_t addr)
{
	const uint_fast16_t firstZeroAddr = 0x200 + 0x10 * NUM_TLB_ENTRIES;

	if (addr >= firstZeroAddr)
		return 0;

	if (addr >= mRamSz) {

		fprintf(stderr, "RAM READ FAIL OF 0x%03x\n", (unsigned)addr);
		exit(-1);
	}

	return mRam[addr] & 15;
}

//Assemble a 32-bit value from the 8 RAM nibbles starting at `ofst`; the nibble at `ofst` is the least significant.
//No bounds check is done here, the caller must pass an offset that lies inside the RAM array.
static uint32_t __attribute__((pure)) cpuPrvReadU32(uint_fast16_t ofst)
{
	uint32_t val = 0;
	int_fast8_t pos;

	//walk from the most significant nibble down, shifting the result up as we go
	for (pos = 7; pos >= 0; pos--)
		val = (val << 4) | (uint32_t)(mRam[ofst + pos] & 15);

	return val;
}

//Store a 32-bit value as 8 RAM nibbles starting at `ofst`, least significant nibble first.
//No bounds check is done here, the caller must pass an offset that lies inside the RAM array.
static void cpuPrvWriteU32(uint_fast16_t ofst, uint32_t val)
{
	uint_fast8_t nib;

	for (nib = 0; nib < 8; nib++)
		mRam[ofst + nib] = (val >> (4 * nib)) & 15;
}

//Store the low 4 bits of `val` into RAM nibble `addr`.
//Two things are fatal emulator errors: an address outside the RAM array, and a nonzero value
//written to the first 8 nibbles (they hold the MIPS zero register, which must stay zero).
void cpuExtRamWrite(uint_fast16_t addr, uint_fast8_t val)
{
	if (addr >= mRamSz) {

		fprintf(stderr, "RAM WRITE FAIL OF 0x%03x\n", (unsigned)addr);
		exit(-1);
	}

	if (val && addr < 8) {

		fprintf(stderr, "zero reg written\n");
		exit(-1);
	}

	mRam[addr] = val & 15;
}

//Watch the character stream for the sequence ".\r\n# " and call ctl_cHandler(0) once it has been seen.
//`state` counts how many characters of the sequence have matched so far. Once it reaches 5 it is never
//reset here, so the handler is called again for every further character.
static void cpuPrvPutcharReport(char chr)
{
	extern void ctl_cHandler(int v);
	static const char wanted[5] = {'.', '\r', '\n', '#', ' '};	//wanted[n] is the char that advances state n to n + 1
	static uint8_t state = 0;

	//mid-sequence: either advance or fall back to the start
	if (state >= 1 && state <= 4)
		state = (chr == wanted[state]) ? state + 1 : 0;

	//at the start (possibly because we just fell back): a '.' begins a new match
	if (state == 0 && chr == wanted[0])
		state = 1;

	if (state == 5) {

		ctl_cHandler(0);
	}
}

//Dump the TLB bookkeeping to stderr: first the 31 hash buckets (empty / invalid marker / head entry index),
//then every TLB entry with its raw hi and lo words, its 4 raw link nibbles, the decoded prev/next links,
//the ASID (bits 6..11 of hi) and the G/V/D/N flags (bits 8..11 of lo, upper case when set).
static void cpuPrvTlbReport(void)
{
	uint_fast8_t i;

	for (i = 0; i < 31; i++) {
		const uint8_t marker = mRam[DLC1_HASH_PRESENT + i];

		fprintf(stderr, "TLBHASH[%2u]: ", i);

		switch (marker) {
			case 0:
				fprintf(stderr, "NOT PRESENT\n");
				break;

			case 1:
				fprintf(stderr, "ENTRY %2u\n", mRam[DLC1_HASH_HEADS + i]);
				break;

			default:
				fprintf(stderr, "INVAL MARKER %u\n", marker);
				break;
		}
	}

	for (i = 0; i < NUM_TLB_ENTRIES; i++) {
		const uint8_t *links = &mRamSta[2 * 64 + 4 * i];		//the 4 link nibbles of this entry
		const uint32_t hi = cpuPrvReadU32(2 * 256 + 16 * i + 8 * 0);
		const uint32_t lo = cpuPrvReadU32(2 * 256 + 16 * i + 8 * 1);
		const uint_fast8_t nextIdx = links[0];
		const bool haveNext = links[1];
		const uint_fast8_t prevIdx = links[2] + 16 * (1 & links[3]);	//5 bits: nibble 2 plus bit 0 of nibble 3
		const bool havePrev = !!(links[3] & 2);

		fprintf(stderr, "ENTRY %2u {0x%08x 0x%08x 0x%02x 0x%02x 0x%02x 0x%02x}. 0x%08x -> 0x%08x\n",
			i, hi, lo, links[0], links[1], links[2], links[3],
			hi & 0xfffff000, lo & 0xfffff000);

		//without a previous entry the "prev" field is shown with a "bk" prefix instead
		fprintf(stderr, "  prev: ");
		if (!havePrev)
			fprintf(stderr, "bk %-2u", prevIdx);
		else
			fprintf(stderr, "%-2u  ", prevIdx);

		fprintf(stderr, "  next: ");
		if (!haveNext)
			fprintf(stderr, "NONE");
		else
			fprintf(stderr, "%-2u  ", nextIdx);

		fprintf(stderr, " ASID %2u, %c%c%c%c\n", (hi >> 6) & 0x3f, (lo & 0x100) ? 'G' : 'g', (lo & 0x200) ? 'V' : 'v', (lo & 0x400) ? 'D' : 'd', (lo & 0x800) ? 'N' : 'n');
	}
}

//Consistency check of the TLB hash structure. On any violation the problem is described on stderr,
//followed by the MIPS cycle count, the PC and a full TLB report, and the emulator exits.
//Checked:
//  1. every bucket marker is 0 or 1
//  2. the head entry of each used bucket has no "previous" entry and names that bucket as its owner
//  3. following the "next" links, each entry's "previous" link matches the entry we came from,
//     and no entry is reached twice
//  4. every one of the NUM_TLB_ENTRIES entries is reached from some bucket
//  5. no two entries outside 0x80000000..0xbfffffff share a VA while either both being global or having the same ASID
static void cpuPrvTlbCheck(void)
{
	bool success = true;
	uint_fast8_t i, j;
	uint16_t entriesHit = 0;

	//verify all heads point to entries that point back to them
	//walk each chain verifying that we hit each entry once
	for (i = 0; i < 31; i++) {

		const uint8_t marker = mRam[DLC1_HASH_PRESENT + i];
		uint_fast8_t entryIdx, prevIdx;
		int16_t expectedPrev = -1;		//index of the entry we came from, -1 while at the chain head
		bool havePrev;

		if (!marker)
			continue;

		if (marker != 1) {

			fprintf(stderr, "HEAD %u has invalid present marker %u\n", i, marker);
			success = false;
			continue;
		}

		entryIdx = mRam[DLC1_HASH_HEADS + i];

		//the index is used as an array index and as a shift count below, so it must be a real entry
		if (entryIdx >= NUM_TLB_ENTRIES) {

			fprintf(stderr, "HEAD %u points to nonexistent entry %u\n", i, entryIdx);
			success = false;
			continue;
		}

		prevIdx = mRamSta[2 * 64 + 4 * entryIdx + 2] + 16 * (1 & mRamSta[2 * 64 + 4 * entryIdx + 3]);
		havePrev = !!(mRamSta[2 * 64 + 4 * entryIdx + 3] & 2);

		if (havePrev) {

			fprintf(stderr, "HEAD %u points to entry %u, but that entry claims to have a previous entry\n", i, entryIdx);
			success = false;
			continue;
		}

		//for a chain head the "prev" field holds the number of the owning bucket
		if (prevIdx != i) {

			fprintf(stderr, "HEAD %u points to entry %u, but that entry claims to belong to bucket %u\n", i, entryIdx, prevIdx);
			success = false;
			continue;
		}

		//walk the chain; this is skipped entirely once any earlier check has failed
		while (success) {

			const uint_fast8_t nextIdx = mRamSta[2 * 64 + 4 * entryIdx + 0];
			const bool haveNext = mRamSta[2 * 64 + 4 * entryIdx + 1];

			prevIdx = mRamSta[2 * 64 + 4 * entryIdx + 2] + 16 * (1 & mRamSta[2 * 64 + 4 * entryIdx + 3]);
			havePrev = !!(mRamSta[2 * 64 + 4 * entryIdx + 3] & 2);

			if (entriesHit & (1 << entryIdx)) {
				fprintf(stderr, "entry %u already seen\n", entryIdx);
				success = false;
				break;
			}
			entriesHit |= 1 << entryIdx;

			if (havePrev && expectedPrev < 0) {
				fprintf(stderr, "entry %u has prev, unexpectedly\n", entryIdx);
				success = false;
			}
			else if (!havePrev && expectedPrev >= 0) {
				fprintf(stderr, "entry %u has no prev, but expected to point to %u\n", entryIdx, expectedPrev);
				success = false;
			}
			else if (havePrev && prevIdx != expectedPrev) {
				fprintf(stderr, "entry %u claims prev of %u, but expected to point to %u\n", entryIdx, prevIdx, expectedPrev);
				success = false;
			}
			else if (!haveNext) {
				break;
			}
			else if (nextIdx >= NUM_TLB_ENTRIES) {
				fprintf(stderr, "entry %u points to nonexistent next entry %u\n", entryIdx, nextIdx);
				success = false;
			}
			else {

				expectedPrev = entryIdx;
				entryIdx = nextIdx;
			}
		}
	}

	if (entriesHit != (1 << NUM_TLB_ENTRIES) - 1) {

		for (i = 0; i < NUM_TLB_ENTRIES; i++) {
			if (!(entriesHit & (1 << i)))
				fprintf(stderr, "entry %u not traversed\n", i);
		}
		success = false;
	}

	//verify no duplicate VAs in translateable space
	for (i = 0; i < NUM_TLB_ENTRIES; i++) {

		const uint32_t hi_i = cpuPrvReadU32(2 * 256 + 16 * i + 8 * 0);
		const uint32_t lo_i = cpuPrvReadU32(2 * 256 + 16 * i + 8 * 1);
		const uint32_t va_i = hi_i & 0xfffff000;
		const uint8_t asid_i = (hi_i >> 6) & 0x3f;
		const bool global_i = !!(lo_i & 0x100);

		//entries whose VA lies in 0x80000000..0xbfffffff are excluded from the duplicate check
		if (va_i >= 0x80000000 && va_i < 0xc0000000)
			continue;

		for (j = 0; j < NUM_TLB_ENTRIES; j++) {

			if (i == j)
				continue;

			const uint32_t hi_j = cpuPrvReadU32(2 * 256 + 16 * j + 8 * 0);
			const uint32_t lo_j = cpuPrvReadU32(2 * 256 + 16 * j + 8 * 1);
			const uint32_t va_j = hi_j & 0xfffff000;
			const uint8_t asid_j = (hi_j >> 6) & 0x3f;
			const bool global_j = !!(lo_j & 0x100);		//must come from entry j's own lo word, not entry i's

			if (va_i == va_j && ((global_i && global_j) || (asid_i == asid_j))) {
				fprintf(stderr, "entry %u and %u have a VA conflict\n", i, j);
				success = false;
			}
		}
	}


	if (!success) {
		fprintf(stderr, "@ mips cycles: %u with PC 0x%08x\n", numMipsCycles, cpuPrvReadU32(DCL1_PC));
		cpuPrvTlbReport();
		exit(-1);
	}
}

//Called once per emulated MIPS instruction (ooo == false) or out of band (ooo == true).
//In-order calls update the instruction count and the per-instruction 4004 cycle statistics.
//Every call runs the TLB consistency check. The state dump to stderr (PC, the 32 general registers,
//HI/LO, STATUS, CAUSE, BadVA, EntryHi/Lo and the TLB) is only produced for out-of-band calls or
//when the global `report` flag is set.
static void cpuPrvReportMipsState(bool ooo)
{
	static const char regNames[][3] = {
		"$0", "at", "v0", "v1", "a0" ,"a1", "a2", "a3",
		"t0", "t1", "t2", "t3", "t4", "t5", "t6", "t7",
		"s0", "s1", "s2" ,"s3", "s4", "s5", "s6", "s7",
		"t8", "t9", "k0", "k1", "gp", "sp", "fp", "ra",
	};
	static uint64_t prevCy = 0;				//cycle counter at the previous in-order call, 0 = none yet
	static uint32_t totalInstrs = 0;		//number of instructions that contributed to totalCy
	uint32_t prevInstrCy = 0, avgInstrCy = 0;
	const uint32_t pc = cpuPrvReadU32(DCL1_PC);
	const uint64_t curCy = cpuGetCy();
	unsigned reg;

	if (!ooo) {
		if (prevCy) {
			prevInstrCy = curCy - prevCy;
			totalCy += prevInstrCy;
			totalInstrs++;
			avgInstrCy = (totalCy + totalInstrs / 2) / totalInstrs;		//rounded to nearest
		}
		numMipsCycles++;
		prevCy = curCy;
	}

	cpuPrvTlbCheck();

	if (!ooo && !report)
		return;

	fprintf(stderr, "MIPS [0x%08x] = 0x%08x  //next pc 0x%08x (prev instr used %u cy, avg %u cy, %u MIPS instrs executed)\r\n",
			pc, cpuPrvReadU32(DCL1_TMP32C), cpuPrvReadU32(DCL1_NPC), prevInstrCy, avgInstrCy, numMipsCycles);

	//general register N lives at RAM offset 8 * N; print them eight to a line
	for (reg = 0; reg < 32; reg++) {
		fprintf(stderr, "  %s=0x%08x", regNames[reg], cpuPrvReadU32(8 * reg));
		if ((reg & 7) == 7)
			fprintf(stderr, "\r\n");
	}

	fprintf(stderr, "  HI=0x%08x  LO=0x%08x  STATUS=0x%08x CAUSE=0x%08x, BadVA=0x%08x, ENTRY{hi 0x%08x  lo 0x%08x}\r\n",
		cpuPrvReadU32(DCL1_HI), cpuPrvReadU32(DCL1_LO), cpuPrvReadU32(DCL1_STATUS), cpuPrvReadU32(DCL1_CAUSE),
		cpuPrvReadU32(DCL1_BADVA), cpuPrvReadU32(DCL1_ENTRYHI), cpuPrvReadU32(DCL1_ENTRYLO));

	cpuPrvTlbReport();
}

//Emulator-only "hypercall" hook.
//  which == 3: an emulated MIPS instruction is being reported (statistics, TLB check, optional dump)
//  which == 7: a character, given as two nibbles in 4004 registers 0 (high) and 1 (low), is fed to the boot-time watcher
//Any other number is a fatal emulator error.
//Returns the incoming C and A values packed as 16 * C + A.
uint8_t cpuExtHyper(uint_fast8_t which, uint_fast8_t A, uint_fast8_t C)
{
	if (which == 3) {			//report mips instr

		cpuPrvReportMipsState(false);
	}
	else if (which == 7) {		//putchar(r0:r1)		this is only here for boot time collection

		const char chr = cpuRegRead(0) * 16 + cpuRegRead(1);

		cpuPrvPutcharReport(chr);		//for boot timing
	}
	else {

		fprintf(stderr, "UNKNOWN HYPERCALL %u\n", which);
		exit(-1);
	}

	return 16 * C + A;
}

//Sector-read callback handed to vsdInit(): forwards to socExtSdSecRead() and returns its result.
//userData is not used (socInit() registers the callback with NULL).
static bool socPrvVsdSecRead(void *userData, uint32_t sec, uint8_t *dst)
{
	(void)userData;

	return socExtSdSecRead(sec, dst);
}

//Sector-write callback handed to vsdInit(): forwards to socExtSdSecWrite() and returns its result.
//userData is not used (socInit() registers the callback with NULL).
static bool socPrvVsdSecWrite(void *userData, uint32_t sec, const uint8_t *src)
{
	(void)userData;
	
	return socExtSdSecWrite(sec, src);
}

static void socPrvGetPsramTotalStats(struct PsramStats *dst)
{
	struct PsramStats src0 = {}, src1 = {};

	if (mPsram0)
		vpsramGetStats(mPsram0, &src0);
	if (mPsram1)
		vpsramGetStats(mPsram1, &src1);

	dst->numTimesSelected = src0.numTimesSelected + src1.numTimesSelected;
	dst->numCyclesSelected = src0.numCyclesSelected + src1.numCyclesSelected;
	dst->numTimesSelectedPage0 = src0.numTimesSelectedPage0 + src1.numTimesSelectedPage0;
	dst->longestSelectedDuration = src0.longestSelectedDuration > src1.longestSelectedDuration ? src0.longestSelectedDuration : src1.longestSelectedDuration;
}

//Fetch PSRAM statistics into *sta.
//  which < 0  - both chips combined
//  which == 0 - PSRAM0 only
//  which > 0  - PSRAM1 only
//Asking for a single chip that is absent yields an all-zero structure.
static void socPrvGetPsramStatsByIndex(struct PsramStats *sta, int8_t which)
{
	struct VPSRAM *chip;

	if (which < 0) {

		socPrvGetPsramTotalStats(sta);
		return;
	}

	chip = (which > 0) ? mPsram1 : mPsram0;

	if (!chip)
		memset(sta, 0, sizeof(*sta));
	else
		vpsramGetStats(chip, sta);
}

//Return the numTimesSelected statistic for the requested PSRAM
//(which < 0: both chips combined, 0: PSRAM0, > 0: PSRAM1).
uint64_t socPrvGetPsramNumSelections(int8_t which)
{
	struct PsramStats sta;

	socPrvGetPsramStatsByIndex(&sta, which);

	return sta.numTimesSelected;
}

//Return the numTimesSelectedPage0 statistic for the requested PSRAM
//(which < 0: both chips combined, 0: PSRAM0, > 0: PSRAM1).
uint64_t socPrvGetPsramNumSelectionsPage0(int8_t which)
{
	struct PsramStats sta;

	socPrvGetPsramStatsByIndex(&sta, which);

	return sta.numTimesSelectedPage0;
}

//Return the numCyclesSelected statistic for the requested PSRAM
//(which < 0: both chips combined, 0: PSRAM0, > 0: PSRAM1).
uint64_t socPrvGetPsramSelectedTicks(int8_t which)
{
	struct PsramStats sta;

	socPrvGetPsramStatsByIndex(&sta, which);

	return sta.numCyclesSelected;
}

//Return the longestSelectedDuration statistic for the requested PSRAM
//(which < 0: the larger of the two chips' values, 0: PSRAM0, > 0: PSRAM1).
uint32_t socPrvGetPsramLongestSelectedTicks(int8_t which)
{
	struct PsramStats sta;

	socPrvGetPsramStatsByIndex(&sta, which);

	return sta.longestSelectedDuration;
}

bool socInit(const uint8_t *rom, uint32_t romSz, uint8_t *ram, uint8_t *ramSta, uint32_t ramSz)
{
	mRomAll = rom;
	mRam = ram;
	mRamSta = ramSta;
	mRomSz = romSz;
	mRamSz = ramSz;
	mRomCurPage = mRomAll + 4096;	//match real hw

	mPsram0 = vpsramInit("PSRAM0", 0x800000);
	mPsram1 = vpsramInit("PSRAM1", 0x800000);
	mVSD = vsdInit(socPrvVsdSecRead, socPrvVsdSecWrite, NULL, socExtSdGetSize(), false);
	mDISP = vdispInit();
	mUART = vuartInit();

	if (ramSz > 2048 || romSz > 8192)
		return false;

	return cpuInit();
}

//Main emulation loop, never returns. Executes one 4004 instruction per iteration and additionally
//  - every 256 iterations: lets the UART do its periodic work and copies its IRQ pin to the CPU test signal
//  - every 1048576 iterations: refreshes the emulated-time readout on screen row 7
void socRun(void)
{
	const uint32_t cyPerSecond = 740000 / 8;		//this many emulated cycles are shown as one second
	uint32_t i = 0;

	while(1) {
		i++;

		//the UART is optional elsewhere in this file, so do not assume it exists here either
		if (!(i & 0xff) && mUART) {
			vuartPeriodic(mUART);
			//here for speed
			cpuPrvSetTestSignal(vuartGetIrqPinVal(mUART));	//our onboard FET + resistor will convert (0V -> -10V, 3.3V -> 0V) so logical values are preserved. 1->1, 0->0. this means that "T" bit is active low UART interrupt
		}
		if (!(i & 0xfffff)) {

			uint64_t cy = cpuGetCy();
			uint32_t seconds = cy / cyPerSecond, days, hours, minutes;

			//split the running total into days / hours / minutes / seconds
			minutes = seconds / 60;
			seconds %= 60;
			hours = minutes / 60;
			minutes %= 60;
			days = hours / 24;
			hours %= 24;

			mvprintw(7, 0, "REALTIME: %u days, %02u:%02u:%02u  ", days, hours, minutes, seconds);
		}
		cpuRunInstr();
	}
}

