#include "memory.h"
#include "plat.h"
#include <stdint.h>
#include <stddef.h>

/* Non-zero when the address p is a multiple of a (a must be non-zero). */
int aligned(const void* p, int a) {
	uintptr_t addr = (uintptr_t)p;
	return !(addr % (uintptr_t)a);
}

/* Round s up to the next multiple of a; a must be a power of two. */
int align_size(int s, int a) {
	const int mask = -a; /* clears the low log2(a) bits */
	int bumped = s + (a - 1);
	return bumped & mask;
}

/* Set the first `size` bytes of buf to zero; does nothing when size <= 0. */
void zero(void* buf, int size) {
	uint8_t* cursor = buf;
	for (; size > 0; size--)
		*cursor++ = 0;
}

/*
 * Round the address ad up to the next multiple of al.
 * al must be a power of two; an already aligned ad is returned unchanged.
 */
uintptr_t align_address(uintptr_t ad, size_t al) {
	const size_t mask = al - 1;
	uintptr_t bumped = ad + mask;
	/* dropping the low bits == bumped & ~mask */
	return bumped - (bumped & mask);
}

/*
 * Set up arena a over the caller-provided buffer mem of `size` bytes.
 * The arena starts empty with no outstanding push mark.
 */
void init_arena(Arena* a, void* mem, int size) {
	a->last_push = 0;
	a->ptr = 0;
	a->size = size;
	a->buf = mem;
}

/* Discard every allocation and push mark, keeping the arena's buffer. */
void clear_arena(Arena* a) {
	a->last_push = a->ptr = 0;
}

/*
 * Bump-allocate `size` raw bytes from the current end of the arena.
 * No alignment is applied; capacity is only checked by assert.
 */
void* imp_arena_alloc(Arena* a, int size) {
	void* block;
	assert(a->ptr + size <= a->size);
	block = a->buf + a->ptr;
	a->ptr += size;
	return block;
}

/* Allocate `size` bytes from a at allocation_default_alignment. */
void* arena_alloc(Arena* a, int size) {
	const int alignment = allocation_default_alignment;
	return arena_alloc_aligned(a, size, alignment);
}

/*
 * Allocate `size` bytes from arena a whose address is a multiple of
 * `align` (align must be a power of two).
 *
 * Only the padding actually needed to reach the next aligned address is
 * consumed, so a->ptr advances by exactly padding + size. The returned
 * address is align_address(&a->buf[a->ptr], align), taken before the call.
 */
void* arena_alloc_aligned(
	Arena* a,
	int size,
	int align
) {
	uintptr_t base;
	int pad;
	char* p;
	base = (uintptr_t)&a->buf[a->ptr];
	pad = (int)(align_address(base, align) - base);
	p = imp_arena_alloc(a, pad + size);
	/* derive the result from p rather than from an integer */
	return p + pad;
}

/*
 * Save the current allocation point so that a later arena_pop() releases
 * everything allocated after it. Pushes nest.
 *
 * A link holding the previous a->last_push is allocated from the arena
 * itself via arena_alloc(). a->last_push stores (offset at push time + 1),
 * so 0 keeps meaning "nothing pushed" (the value init_arena/clear_arena
 * set) even when the push happens at offset 0.
 */
void arena_push(Arena* a) {
	int last = a->last_push, * prev;
	int mark = a->ptr;
	prev = arena_alloc(a, sizeof *prev);
	*prev = last;
	a->last_push = mark + 1;
}

/*
 * Release everything allocated since the matching arena_push(): restore
 * a->ptr to the offset recorded by that push and restore the previous mark.
 */
void arena_pop(Arena* a) {
	int mark;
	assert(a->last_push); /* pop without a matching push */
	mark = a->last_push - 1;
	/* arena_alloc() placed the link at the next
	 * allocation_default_alignment boundary, not at buf[mark] itself */
	a->last_push = *(int*)align_address(
		(uintptr_t)&a->buf[mark],
		allocation_default_alignment
	);
	a->ptr = mark;
}

/*
 * Prepare `size` bytes at mem as a heap holding a single free block that
 * spans the whole buffer. mem must be 4-byte aligned and size > 8.
 */
void init_heap(Heap* h, void* mem, int size) {
	assert(aligned(mem, 4));
	assert(size > 8);
	h->blocks = 1;
	h->size = size;
	h->buf = mem;
	/* header of the only block: total size, top ("in use") bit clear */
	*(int*)h->buf = size;
}

/*
 * First-fit search for a free block that can hold `size` payload bytes
 * plus an int header.
 *
 * Block layout: blocks are contiguous in h->buf. Each starts with an int
 * header whose low bits hold the total block size in bytes (header
 * included) and whose top bit is set while the block is in use.
 *
 * When the leftover front part of a free block can remain a valid free
 * block (more than two headers' worth), the block is split: the front
 * stays free and the allocation is carved from the back. Otherwise the
 * whole free block is handed out, slack included.
 *
 * Returns the payload pointer (just past the header), or 0 when no free
 * block is large enough.
 */
void* imp2_heap_alloc(
	Heap* h,
	int size
) {
	int o, i;
	int hs = sizeof(int);
	int as = align_size(size + hs, hs); /* bytes needed incl. header */
	int f = ~((unsigned)-1 >> 1);       /* "in use" flag: the top bit */
	for (i = o = 0; i < h->blocks; i++) {
		int* phdr = (int*)&h->buf[o];
		int hdr = *phdr, bs;
		bs = hdr & ~f;
		assert(bs);
		assert(aligned(phdr, sizeof hdr));
		if ((~hdr & f) && bs >= as) {
			int ps = bs - as;
			/* largest multiple of hs strictly below ps, so the new
			 * header lands hs-aligned inside this block */
			int aps = align_size(ps, hs) - hs;
			if (aps > hs * 2) {
				int* nhdr = (int*)&h->buf[o + aps];
				assert(aligned(nhdr, sizeof hs));
				phdr[0] = aps;
				nhdr[0] = (as + (ps - aps)) | f;
				h->blocks++;
				return &nhdr[1];
			}
			/* exact fit, or remainder too small to split off:
			 * hand out the whole block instead of skipping it */
			phdr[0] |= f;
			return phdr + 1;
		}
		o += bs;
	}
	return 0;
}

/*
 * Allocate s payload bytes with imp2_heap_alloc(); if that fails, coalesce
 * free blocks with heap_defrag() and try exactly once more.
 * Returns 0 when both attempts fail.
 */
void* imp_heap_alloc(Heap* h, int s) {
	void* p;
	if ((p = imp2_heap_alloc(h, s)) != 0)
		return p;
	heap_defrag(h);
	return imp2_heap_alloc(h, s);
}

/*
 * Mark the block whose payload starts at p as free by clearing the top
 * ("in use") bit of the int header just before p. Neighbouring free
 * blocks are not merged here.
 */
void imp_heap_free(Heap* h, void* p) {
	const unsigned used_bit = ~((unsigned)-1 >> 1);
	int* hdr;
	assert((char*)p > h->buf);
	assert((char*)p < h->buf + h->size);
	hdr = (int*)p - 1;
	assert(*hdr & used_bit); /* double free */
	(void)h; /* h is otherwise only referenced by the asserts */
	*hdr &= ~used_bit;
}

/*
 * Allocate `size` bytes from h whose address is a multiple of `align`.
 *
 * align extra bytes are requested so that the returned pointer a can sit
 * 1..align bytes past the raw payload p. The distance a - p is stored in
 * the byte a[-1] (mod 256, so 256 is stored as 0) for heap_free_aligned()
 * and heap_block_size() to recover p. Returns 0 when the heap is full.
 */
void* heap_alloc_aligned(
	Heap* h,
	int size,
	int align
) {
	unsigned char* p, * a;
	ptrdiff_t shift;
	/* the one-byte shift encoding only covers 1..256, and align_address
	 * needs a power of two */
	assert(align > 0 && align <= 256 && (align & (align - 1)) == 0);
	size += align;
	p = imp_heap_alloc(h, size);
	if (!p) { return 0; }
	a = (unsigned char*)align_address((uintptr_t)p, align);
	/* always leave at least one byte before a to hold the shift */
	a += align * (unsigned)(p == a);
	shift = a - p;
	a[-1] = (unsigned char)(shift & 0xff);
	return a;
}

/*
 * Free a pointer returned by heap_alloc_aligned() (or heap_alloc()).
 *
 * The byte p[-1] holds the distance back to the raw payload, stored mod
 * 256, so a stored 0 stands for 256. Passing 0 (NULL) is a no-op.
 */
void heap_free_aligned(Heap* h, void* p) {
	unsigned char* a;
	ptrdiff_t shift;
	if (!p)
		return;
	a = p;
	shift = a[-1];
	/* explicit test: `shift += 256 * shift == 0` parses as
	 * `shift += (256 * shift) == 0` and would yield 1, not 256 */
	if (shift == 0)
		shift = 256;
	a -= shift;
	imp_heap_free(h, a);
}

/*
 * Coalesce runs of adjacent free blocks.
 *
 * Walks the blocks in address order. When a free block is found, every
 * free block directly after it is absorbed: the first header is rewritten
 * with the combined size, and h->blocks is reduced by the number of
 * absorbed blocks at the end. Used blocks are left untouched.
 */
void heap_defrag(Heap* h) {
	const int used = ~((unsigned)-1 >> 1); /* top bit marks "in use" */
	int idx = 0, off = 0, absorbed = 0;
	while (idx < h->blocks) {
		int* head = (int*)&h->buf[off];
		int len = *head & ~used;
		assert(aligned(head, sizeof *head));
		idx++;
		if ((*head & used) == 0) {
			while (idx < h->blocks) {
				int next = *(int*)&h->buf[off + len];
				if (next & used)
					break;
				len += next & ~used;
				idx++;
				absorbed++;
			}
			*head = len;
		}
		off += len;
	}
	h->blocks -= absorbed;
}

/* Allocate `size` bytes from h at allocation_default_alignment; 0 on failure. */
void* heap_alloc(Heap* h, int size) {
	const int alignment = allocation_default_alignment;
	return heap_alloc_aligned(h, size, alignment);
}

/* Release a pointer obtained from heap_alloc(); forwards to heap_free_aligned(). */
void heap_free(Heap* h, void* p) {
	heap_free_aligned(h, p);
}

/*
 * Return the size field from the header of the heap block backing ptr.
 *
 * ptr must come from heap_alloc()/heap_alloc_aligned(): ptr[-1] holds the
 * distance back to the raw payload, stored mod 256 (0 stands for 256).
 * The returned value is the raw block size from the header with the
 * "in use" bit masked off. It counts the int header and alignment padding
 * as well as the usable bytes.
 */
int heap_block_size(void* ptr) {
	unsigned char* a;
	int f = ~((unsigned)-1 >> 1);
	ptrdiff_t shift;
	a = ptr;
	shift = a[-1];
	/* explicit test: `shift += 256 * shift == 0` parses as
	 * `shift += (256 * shift) == 0` and would yield 1, not 256 */
	if (shift == 0)
		shift = 256;
	a -= shift;
	return ((int*)a)[-1] & ~f;
}

/*
void print_blocks(Heap* h) {
	int i, o;
	int fb = ~((unsigned)-1 >> 1);
	for (i = o = 0; i < h->blocks; i++) {
		int b = *(int*)&h->buf[o];
		int bs = b & ~fb;
		int f = ~b & fb;
		printf("%s  %d\n", f? "free": "    ", bs);
		o += bs;
	}
	assert(o == h->size);
}
*/