#include "stdafx.h"
#include "Seh64.h"
#include "Seh.h"

#include "Code/X64/Asm.h"

#if defined(WINDOWS) && defined(X64)

#include "Code/FnState.h"
#include "Gc/Gc.h"
#include "Gc/CodeTable.h"

namespace code {
	namespace eh {

		/**
		 * Function-table callback for runtime-generated code (NOTE(review): presumably the
		 * callback registered with RtlInstallFunctionTableCallback; confirm at the registration
		 * site).
		 *
		 * Finds the code allocation that contains `pc`. It then patches the RUNTIME_FUNCTION
		 * stored inside that allocation so that all offsets in it are relative to `base`, since
		 * the GC may have moved the code since it was last used. Finally it returns a pointer to
		 * that RUNTIME_FUNCTION.
		 *
		 * Returns null when `pc` is not inside any known code allocation. It also returns null
		 * when `pc` lies past the point reported as "EndAddress".
		 *
		 * The layout manipulated here is the one created by WindowsOutput's constructor.
		 */
		RUNTIME_FUNCTION *exceptionCallback(void *pc, void *base) {
			storm::CodeTable &codes = storm::codeTable();

			byte *start = (byte *)codes.find(pc);
			if (!start)
				return null;

			size_t allocSize = storm::Gc::codeSize(start);
			Nat rva = Nat(size_t(start) - size_t(base));

			// The last Nat of the allocation holds the offset (from the start of the code) of
			// the EH data, which begins with a RUNTIME_FUNCTION.
			Nat ehOffset = *(Nat *)(start + allocSize - sizeof(Nat));

			// Report nothing for addresses past where "EndAddress" will point. This keeps the
			// unwinder from being confused if an exception occurs inside the small shim at the
			// end of the allocation.
			if (size_t(start) + ehOffset < size_t(pc))
				return null;

			// The UNWIND_INFO immediately follows the RUNTIME_FUNCTION.
			Nat unwindOffset = ehOffset + Nat(sizeof(RUNTIME_FUNCTION));

			// Re-base the RUNTIME_FUNCTION, as the code may have moved. EndAddress is a bit too
			// late: it presumably also covers the block table stored just before the EH data.
			// Finding the exact end of the code is cumbersome, so this is accepted.
			RUNTIME_FUNCTION *entry = (RUNTIME_FUNCTION *)(start + ehOffset);
			entry->BeginAddress = rva;
			entry->EndAddress = rva + ehOffset;
			entry->UnwindData = rva + unwindOffset;

			// The handler RVA is stored after the unwind codes. Each code occupies 2 bytes and
			// their count is padded to an even number. The handler points at the trailer at the
			// very end of the allocation. NOTE(review): the 10 bytes are presumably the 6-byte
			// jump shim plus the trailing Nat read above.
			const Nat trailerSize = 10;
			UnwindInfo *info = (UnwindInfo *)(start + unwindOffset);
			Nat paddedCodes = roundUp(info->unwindCount, byte(2));
			DWORD *handler = (DWORD *)(start + unwindOffset + sizeof(UnwindInfo) + paddedCodes*2);
			*handler = rva + Nat(allocSize) - trailerSize;

			return entry;
		}

		// Layout of the dispatcher context that Windows passes to language-specific exception
		// handlers on x64. It mirrors the DISPATCHER_CONTEXT structure described in the Windows
		// documentation:
		// https://learn.microsoft.com/en-us/cpp/build/exception-handling-x64?view=msvc-170
		// In this file the structure is only accessed by casting the `void *dispatch` pointer
		// received from the OS, so the field order must match the Windows definition exactly.
		// NOTE(review): the Windows definition may contain further trailing fields that are not
		// declared here. That is harmless as long as this type is never allocated or copied by
		// value.
		struct DispatchContext {
			// ControlPc: the pc in the frame being handled (used to look up the active block).
			size_t pc;
			// ImageBase: base address that the RVAs in `fnEntry` are relative to.
			size_t base;
			// FunctionEntry: the RUNTIME_FUNCTION for the frame being handled.
			RUNTIME_FUNCTION *fnEntry;
			// EstablisherFrame of the frame being handled.
			size_t establisherFrame;
			// TargetIp (used during unwinding).
			size_t targetIp;
			// ContextRecord for the frame.
			CONTEXT *contextRecord;
			// LanguageHandler: address of the language-specific handler.
			void *languageHandler;
			// HandlerData: the byte right after the stored handler address in the unwind info
			// (extractFrame relies on this to locate the end of the code).
			void *handlerData;
			// HistoryTable: only present in later SDKs. Passed on to RtlUnwindEx by resumeFrame.
			void *historyTable;
		};

		/**
		 * Description of a single entry in the block table.
		 *
		 * NOTE(review): this struct is not referenced in the visible part of this file. The
		 * block table is read via findFunctionStateFromEnd in extractFrame. Verify whether it
		 * is still used elsewhere.
		 */
		struct FnBlock {
			// At what code offset does this entry start to apply? (presumably relative to the
			// start of the function's code - confirm against the writer in WindowsOutput)
			Nat offset;

			// Which block is active from `offset` onwards (presumably until the next entry).
			Nat block;
		};

		/**
		 * Recover a description of a Storm frame from the dispatcher context passed to the
		 * language-specific handler.
		 *
		 * `frame` is recorded as the frame's stack pointer. `dispatch` points to the Windows
		 * dispatcher context. `er` and `ctx` are not used here. See Code/X64/WindowsOutput.cpp
		 * for the layout of the data walked below.
		 */
		SehFrame extractFrame(_EXCEPTION_RECORD *er, void *frame, _CONTEXT *ctx, void *dispatch) {
			DispatchContext *dc = (DispatchContext *)dispatch;

			// HandlerData points to the byte right after the stored address of the handler.
			// The 6-byte jump shim that calls the EH function comes next, then a Nat. Rounding
			// the end of that up to pointer alignment gives the end of the code portion, and
			// the GcCode (reference table) is placed right there by the object layout.
			const size_t shimSize = 6;
			size_t codeEnd = size_t(dc->handlerData) + shimSize + sizeof(Nat);
			codeEnd = roundUp(codeEnd, sizeof(void *));

			GcCode *refs = (GcCode *)codeEnd;

			SehFrame result;
			result.stackPtr = frame;
			result.binary = code::codeBinaryImpl(refs);
			result.frameOffset = -result.binary->stackOffset();

			// The binary knows where its code starts. The Nat just before the end of the code
			// is the offset of the EH data relative to that start.
			size_t codeStart = size_t(result.binary->address());
			Nat ehOffset = ((Nat *)codeEnd)[-1];
			size_t ehStart = codeStart + ehOffset;

			// The block table sits right before the EH data. Use it to find which
			// part/activation is active at the pc being dispatched.
			Nat active = findFunctionStateFromEnd((void *)ehStart, dc->pc - codeStart);
			code::decodeFnState(active, result.part, result.activation);

			return result;
		}

		// Manual declaration of the Windows RtlUnwindEx routine. It is not always present in
		// the Windows headers, so we declare it ourselves. Used by resumeFrame to unwind to
		// `targetFrame` and continue at `targetIp`, with `returnValue` as the value passed along.
		extern "C"
		NTSYSAPI VOID RtlUnwindEx(PVOID targetFrame, PVOID targetIp, PEXCEPTION_RECORD er,
								PVOID returnValue, PCONTEXT context, void *history);


		/**
		 * Resume execution in the frame described by `frame` at `resume.ip`, unwinding every
		 * frame above it through RtlUnwindEx. `object` is handed to RtlUnwindEx as its return
		 * value.
		 *
		 * The block to clean up to (`resume.cleanUntil`) is passed to the unwind phase as the
		 * last parameter of the exception record `er`. cleanupPartialFrame reads it back from
		 * there.
		 *
		 * This function is not expected to return.
		 */
		void resumeFrame(SehFrame &frame, Binary::Resume &resume, storm::RootObject *object,
						_CONTEXT *ctx, _EXCEPTION_RECORD *er, void *dispatch) {
			DispatchContext *dispatchContext = (DispatchContext *)dispatch;

			er->ExceptionFlags |= EXCEPTION_UNWINDING;

			// Store the target block in the exception parameters. Most likely we could trash
			// the parameters entirely at this stage, since other code should not care during
			// the unwind step. To be on the safe side, we store it as the last parameter.
			//
			// ExceptionInformation only has room for EXCEPTION_MAXIMUM_PARAMETERS entries. If
			// the record is already full (e.g. a foreign exception raised with the maximum
			// number of parameters), overwrite the last entry instead of writing past the end
			// of the array. Either way the value ends up at index NumberParameters - 1, which
			// is where cleanupPartialFrame looks for it.
			DWORD slot = er->NumberParameters;
			if (slot >= EXCEPTION_MAXIMUM_PARAMETERS)
				slot = EXCEPTION_MAXIMUM_PARAMETERS - 1;
			er->ExceptionInformation[slot] = resume.cleanUntil;
			er->NumberParameters = slot + 1;

			RtlUnwindEx(frame.stackPtr, resume.ip, er, object, ctx, dispatchContext->historyTable);

			// Expected to not return...
			dbg_assert(false, L"Failed to unwind the stack!");
		}

		/**
		 * Clean up the partially unwound `frame` up to the block that resumeFrame stored as the
		 * last parameter of the exception record `er`.
		 */
		void cleanupPartialFrame(SehFrame &frame, _EXCEPTION_RECORD *er) {
			DWORD last = er->NumberParameters - 1;
			cleanupPartialFrame(frame, Nat(er->ExceptionInformation[last]));
		}

	}
}

#endif

namespace code {
	namespace eh {
		using namespace code::x64;

		// Registers indexed by their Win64 register number: the position of a register in this
		// array is the number returned by win64Register and accepted by fromWin64Register.
		// NOTE(review): presumably this matches the register numbering used by Windows x64
		// unwind codes, with ptrStack/ptrFrame standing for rsp/rbp at positions 4 and 5.
		// The order is significant - do not reorder.
		static const Reg order[] = {
			rax, rcx, rdx, rbx, ptrStack, ptrFrame, rsi, rdi,
			r8, r9, r10, r11, r12, r13, r14, r15
		};

		/**
		 * Get the Win64 register number of `reg`, i.e. the index of the first entry in `order`
		 * that `same` considers equal to it. Asserts and returns 0 if no entry matches.
		 */
		Nat win64Register(Reg reg) {
			const Nat count = ARRAY_COUNT(order);

			Nat id = 0;
			while (id < count && !same(reg, order[id]))
				id++;

			if (id < count)
				return id;

			assert(false, L"Register not supported!");
			return 0;
		}

		/**
		 * Get the pointer-sized register with Win64 register number `id`. Returns noReg when
		 * `id` is outside the range of known registers.
		 */
		Reg fromWin64Register(Nat id) {
			if (id >= ARRAY_COUNT(order))
				return noReg;

			return asSize(order[id], Size::sPtr);
		}

	}
}
