#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "memory.h"
#include "plat.h"
#include "sc/sh_enums.h"
#include "sc/sh_helpers.h"
#include "str.h"

#define CGLTF_IMPLEMENTATION
#include "cgltf.h"

typedef struct {
	const char* shader;
	const char* depth_shader;
	const char* material;
} Node_Config;

/*
 * One vertex attribute requested by a shader's "mesh" binding, as read by
 * parse_shader_attribs(). Attributes form a singly linked list.
 */
typedef struct Shader_Attrib {
	cgltf_attribute_type target; /* glTF attribute this maps to */
	SVariable_Type type;         /* shader variable type read from the shader file */
	int offset;                  /* byte offset of this attribute inside an output vertex */
	int size;                    /* size in bytes, from svariable_type_size() */
	struct Shader_Attrib* next;  /* next attribute, or null at the end of the list */
} Shader_Attrib;

/* Scratch allocator; parse() clears it before each node is processed. */
Arena arena;
/* Shader directory (argv[1]) and the separator appended to it ("" or "/"). */
const char* shader_dir;
const char* shader_dir_sep;

/* Growable output vertex storage and its capacity in bytes. */
char* vertex_buffer;
int vertex_buffer_size;
/* Growable output index storage and its capacity in bytes. */
uint16_t* index_buffer;
int index_buffer_size;
/* Number of vertices / indices emitted for the mesh currently being built. */
int vertex_count, index_count;
/* Size in bytes of one output vertex for the current shader. */
int vertex_size;
/* Running byte totals over all meshes; main() patches them into the header. */
int full_vert_size = 0;
int full_ind_size = 0;
/* Scratch space for the vertex being assembled (vertex_size bytes). */
float* current_vertex;
/* Per-primitive lookup tables indexed by cgltf_attribute_type; filled by
 * build_attrib_accessors(). A null entry means "not present". */
const cgltf_accessor* accessors[cgltf_attribute_type_max_enum];
Shader_Attrib* target_attribs[cgltf_attribute_type_max_enum];
/* Bounding box of the current mesh; reset in parse_node_mesh(). */
float min_bound[3], max_bound[3];

/*
 * Tests whether JSON token `t` (a span of `src`) is a string token whose text
 * is exactly `s`.
 *
 * Returns 1 on an exact match and 0 otherwise (including when the token is
 * not a string).
 *
 * The lengths are compared first: comparing only the token's length worth of
 * characters would also accept any token that is a prefix of `s`, and an
 * empty token would match every key.
 */
int tcmp(const char* src, const char* s, const jsmntok_t* t) {
	size_t l;
	if (t->type != JSMN_STRING) return 0;
	if (t->end < t->start) return 0;
	l = (size_t)(t->end - t->start);
	return strlen(s) == l && !strncmp(src + t->start, s, l);
}

/*
 * Copies the text covered by token `t` out of `src` into a freshly
 * arena-allocated, NUL-terminated string and returns it.
 */
const char* read_str(const char* src, const jsmntok_t* t) {
	const int len = t->end - t->start;
	char* out = arena_alloc(&arena, len + 1);
	const char* from = src + t->start;
	char* to = out;
	int remaining = len;

	while (remaining-- > 0)
		*to++ = *from++;
	*to = 0;
	return out;
}

void parse_node_cfg(const cgltf_node* n, Node_Config* cfg) {
	int i, c;
	const char* json = n->extras.data;
	jsmn_parser p;
	jsmntok_t toks[32];
	jsmn_init(&p);
	c = jsmn_parse(
		&p,
		json,
		strlen(json),
		toks,
		sizeof toks / sizeof *toks
	);
	if (c == 0) return;
	if (c < 0) {
		print_err("Invalid extras json (or too big).\n");
		pbreak(100);
	}
	for (i = 0; i < c; i++) {
		const jsmntok_t* t = &toks[i];
		if (tcmp(json, "shader", t)) {
			cfg->shader = read_str(json, &t[1]);
			if (string_len(cfg->shader) > 27) {
				print_err(
					"Shader name %s too long (max 27 chars).\n",
					cfg->shader
				);
				pbreak(3478);
			}
			i++;
		}
		if (tcmp(json, "depth_shader", t)) {
			cfg->depth_shader = read_str(json, &t[1]);
			if (string_len(cfg->depth_shader) > 27) {
				print_err(
					"depth_shader name %s too long (max 27 chars).\n",
					cfg->depth_shader
				);
				pbreak(3478);
			}
			i++;
		}
		if (tcmp(json, "material", t)) {
			cfg->material = read_str(json, &t[1]);
			if (string_len(cfg->material) > 27) {
				print_err(
					"Material name %s too long (max 27 chars).\n",
					cfg->material
				);
				pbreak(3479);
			}
			i++;
		}
	}
}

/*
 * Rebuilds the per-primitive lookup tables `accessors` and `target_attribs`.
 *
 * Both tables are cleared first. Then, for every attribute the shader asks
 * for (`desired` list), the first primitive attribute of the same type
 * supplies the accessor; the shader attribute itself is always recorded in
 * `target_attribs`, even when the primitive has no matching data.
 */
void build_attrib_accessors(
	const cgltf_primitive* p,
	Shader_Attrib* desired
) {
	Shader_Attrib* want;
	int attr_count;

	memset(target_attribs, 0, sizeof target_attribs);
	memset(accessors, 0, sizeof accessors);

	attr_count = p->attributes_count;
	for (want = desired; want != 0; want = want->next) {
		int k = 0;
		while (k < attr_count && p->attributes[k].type != want->target)
			k++;
		if (k < attr_count)
			accessors[want->target] = p->attributes[k].data;
		target_attribs[want->target] = want;
	}
}

void read_vertex(
	float* dst,
	uint16_t idx
) {
	int i, c = cgltf_attribute_type_max_enum;
	for (i = 0; i < c; i++) {
		const cgltf_accessor* a = accessors[i];
		Shader_Attrib* t = target_attribs[i];
		if (!t) continue;
		int ec = t->size / 4;
		if (a) {
			if (a->component_type != cgltf_component_type_r_32f) {
				print_err("Only float attributes are supported.\n");
				pbreak(33);
			}
			const cgltf_buffer_view* v = a->buffer_view;
			int j, off = t->offset / 4;
			int sec = (a->stride / 4);
			if (ec >= sec) {
				for (j = 0; j < sec; j++)
					dst[j + off] = ((float*)&((char*)v->buffer->data)[v->offset])
						[idx * sec + j];
				for (; j < ec; j++)
					dst[j + off] = 0.0f;
			} else {
				for (j = 0; j < ec; j++)
					dst[j + off] = ((float*)&((char*)v->buffer->data)[v->offset])
						[idx * sec + j];
			}
		} else
			memset(&dst[t->offset / 4], 0, t->size);
	}
}

/*
 * Appends index `i` to index_buffer, doubling the buffer when it is full,
 * and adds the index size to full_ind_size.
 *
 * If the buffer cannot be grown the failure is reported with
 * print_err()/pbreak(35) and the index is not stored; the existing buffer
 * is left intact.
 */
void add_index(uint16_t i) {
	int s = sizeof i;
	if (index_count * s >= index_buffer_size) {
		/* Doubling a zero capacity would never grow, so seed it instead. */
		int new_size =
			index_buffer_size > 0 ? index_buffer_size * 2 : 2 * s;
		uint16_t* grown = realloc(index_buffer, new_size);
		if (!grown) {
			print_err("Out of memory growing the index buffer.\n");
			pbreak(35);
			return;
		}
		index_buffer = grown;
		index_buffer_size = new_size;
	}
	index_buffer[index_count++] = i;
	full_ind_size += sizeof *index_buffer;
}

/*
 * Compares two vertices of vertex_size bytes float by float.
 * Returns 1 when no pair of components differs by more than 0.0001,
 * 0 otherwise.
 */
int about_eq(float* a, float* b) {
	int remaining;
	for (remaining = vertex_size / 4; remaining > 0; remaining--) {
		const float delta = *a++ - *b++;
		if (fabsf(delta) > 0.0001f)
			return 0;
	}
	return 1;
}

/*
 * Adds `vertex` (vertex_size bytes) to the current mesh.
 *
 * If an approximately equal vertex (see about_eq) is already stored, only
 * an index to it is emitted. Otherwise the vertex is appended to
 * vertex_buffer, an index to it is emitted, and full_vert_size grows by
 * vertex_size.
 *
 * Errors, reported through print_err()/pbreak():
 *   34 - the mesh needs more unique vertices than a 16-bit index can
 *        address;
 *   35 - the vertex buffer could not be grown.
 * In both cases nothing is added.
 */
void add_vertex(float* vertex) {
	int i = 0;
	int needed;
	for (i = 0; i < vertex_count; i++) {
		float* check = (float*)&vertex_buffer[vertex_size * i];
		if (about_eq(vertex, check)) {
			add_index((uint16_t)i);
			return;
		}
	}
	/* The new vertex would get index vertex_count, which must fit uint16_t. */
	if (vertex_count > UINT16_MAX) {
		print_err(
			"Mesh has more than %d unique vertices; "
			"16-bit indices cannot address them.\n",
			UINT16_MAX + 1
		);
		pbreak(34);
		return;
	}
	needed = vertex_count * vertex_size + vertex_size;
	if (needed >= vertex_buffer_size) {
		int new_size = vertex_buffer_size > 0 ? vertex_buffer_size : 1;
		char* grown;
		/* Keep doubling: one doubling may not be enough for a large vertex. */
		while (needed >= new_size)
			new_size *= 2;
		grown = realloc(vertex_buffer, new_size);
		if (!grown) {
			print_err("Out of memory growing the vertex buffer.\n");
			pbreak(35);
			return;
		}
		vertex_buffer = grown;
		vertex_buffer_size = new_size;
	}
	memcpy(
		&vertex_buffer[vertex_size * vertex_count],
		vertex,
		vertex_size
	);
	add_index((uint16_t)vertex_count++);
	full_vert_size += vertex_size;
}

/*
 * Appends one glTF primitive to the mesh currently being built.
 *
 * The primitive must be an indexed triangle list with 2-byte indices. Its
 * float vec3 positions (if present) extend min_bound/max_bound, and every
 * index is expanded into a vertex laid out as the shader requests
 * (`desired`) and passed to add_vertex().
 *
 * p       - primitive to convert.
 * desired - attribute list of the target shader.
 *
 * Errors, reported through print_err()/pbreak():
 *   30 - not a triangle list; 31 - not indexed; 32 - index size not 2 bytes;
 *   36 - index accessor has no buffer data.
 */
void parse_prim(
	const cgltf_primitive* p,
	Shader_Attrib* desired
) {
	int i, c;
	const cgltf_accessor* a, * pa;
	const cgltf_buffer_view* v;
	const uint16_t* indices;
	if (p->type != cgltf_primitive_type_triangles) {
		print_err("Only triangle meshes are supported.\n");
		pbreak(30);
	}
	a = p->indices;
	if (!a) {
		print_err("Meshes must be indexed.\n");
		pbreak(31);
		return;
	}
	if (a->stride != sizeof *index_buffer) {
		/* sizeof yields size_t; cast so the argument matches %d. */
		print_err(
			"Only %d byte indices are supported.\n",
			(int)sizeof *index_buffer
		);
		pbreak(32);
	}
	v = a->buffer_view;
	if (!v || !v->buffer || !v->buffer->data) {
		print_err("Index accessor has no buffer data.\n");
		pbreak(36);
		return;
	}
	c = (int)a->count;
	/* Accessor data starts at the view offset plus the accessor's own offset. */
	indices = (const uint16_t*)(
		(const char*)v->buffer->data + v->offset + a->offset
	);
	build_attrib_accessors(p, desired);
	pa = accessors[cgltf_attribute_type_position];
	if (
		pa &&
		pa->buffer_view &&
		pa->type == cgltf_type_vec3 &&
		pa->component_type == cgltf_component_type_r_32f
	) {
		const char* data = pa->buffer_view->buffer->data;
		size_t s = pa->stride;
		data += pa->buffer_view->offset + pa->offset;
		for (i = 0; i < c; i++) {
			const float* pos = (const float*)&data[indices[i] * s];
			if (pos[0] < min_bound[0]) min_bound[0] = pos[0];
			if (pos[1] < min_bound[1]) min_bound[1] = pos[1];
			if (pos[2] < min_bound[2]) min_bound[2] = pos[2];
			if (pos[0] > max_bound[0]) max_bound[0] = pos[0];
			if (pos[1] > max_bound[1]) max_bound[1] = pos[1];
			if (pos[2] > max_bound[2]) max_bound[2] = pos[2];
		}
	}
	for (i = 0; i < c; i++) {
		read_vertex(current_vertex, indices[i]);
		add_vertex(current_vertex);
	}
}

/*
 * Converts the mesh attached to node `n` and writes one "MESH" record to
 * `outfile`.
 *
 * All primitives of the mesh are merged into the shared vertex/index
 * buffers, then the record is written in this order: the tag "MESH", three
 * 28-byte zero-padded names (shader, depth shader, material), vertex size,
 * index count, vertex count, parent index, the node's local 4x4 matrix,
 * the min and max bounds, the vertex data and finally the index data.
 */
void parse_node_mesh(
	int parent_index,
	const cgltf_node* n,
	const Node_Config* cfg,
	Shader_Attrib* desired,
	FILE* outfile
) {
	const cgltf_mesh* mesh = n->mesh;
	const int prim_count = mesh->primitives_count;
	const char* names[3];
	const int* header[4];
	char name_field[28];
	float local[16];
	int k;

	/* Start a fresh mesh: empty buffers and an inverted bounding box. */
	vertex_count = 0;
	index_count = 0;
	for (k = 0; k < 3; k++) {
		min_bound[k] =  INFINITY;
		max_bound[k] = -INFINITY;
	}

	for (k = 0; k < prim_count; k++)
		parse_prim(&mesh->primitives[k], desired);
	cgltf_node_transform_local(n, local);

	fwrite("MESH", 4, 1, outfile);

	names[0] = cfg->shader;
	names[1] = cfg->depth_shader;
	names[2] = cfg->material;
	for (k = 0; k < 3; k++) {
		zero(name_field, sizeof name_field);
		string_copy(name_field, names[k]);
		fwrite(name_field, 1, sizeof name_field, outfile);
	}

	header[0] = &vertex_size;
	header[1] = &index_count;
	header[2] = &vertex_count;
	header[3] = &parent_index;
	for (k = 0; k < 4; k++)
		fwrite(header[k], 4, 1, outfile);

	fwrite(local, sizeof local, 1, outfile);
	fwrite(min_bound, sizeof min_bound, 1, outfile);
	fwrite(max_bound, sizeof max_bound, 1, outfile);
	fwrite(vertex_buffer, 1, vertex_count * vertex_size, outfile);
	fwrite(index_buffer, 1, index_count * sizeof *index_buffer, outfile);
}

/*
 * Maps a shader attribute name to the glTF attribute type it is fed from.
 * Returns cgltf_attribute_type_invalid for names that are not recognised.
 */
cgltf_attribute_type get_attribute_type(const char* name) {
	static const struct {
		const char* key;
		cgltf_attribute_type type;
	} known[] = {
		{ "position", cgltf_attribute_type_position },
		{ "normal",   cgltf_attribute_type_normal },
		{ "tangent",  cgltf_attribute_type_tangent },
		{ "uv",       cgltf_attribute_type_texcoord },
		{ "colour",   cgltf_attribute_type_color },
		{ "joints",   cgltf_attribute_type_joints },
		{ "weights",  cgltf_attribute_type_weights }
	};
	size_t k;
	for (k = 0; k < sizeof known / sizeof *known; k++) {
		if (string_equal(name, known[k].key))
			return known[k].type;
	}
	return cgltf_attribute_type_invalid;
}

/*
 * Reads exactly `size` bytes from `f` into `dst`.
 * Returns 1 on success, 0 on a short read or read error.
 */
static int shader_read(FILE* f, void* dst, size_t size) {
	return fread(dst, 1, size, f) == size;
}

/*
 * Loads the vertex attribute layout of the "mesh" binding from a compiled
 * shader file.
 *
 * The file is looked up as shader_dir + shader_dir_sep + fname. It must
 * start with the magic "CSH2", followed by the program type and the
 * binding, target, descriptor and option counts. Each binding record is a
 * 24-byte name, a rate and an attribute count, followed by that many
 * 32-byte attribute records (28-byte name + 4-byte type).
 *
 * fname - shader file name relative to the shader directory.
 *
 * Returns an arena-allocated list of attributes with byte offsets assigned
 * in file order (the list itself is linked in reverse file order), or null
 * after reporting an error.
 *
 * Errors, reported through print_err()/pbreak():
 *   303 - file cannot be opened; 304 - bad magic or truncated file;
 *   305 - unknown attribute name; 306 - no "mesh" binding.
 */
Shader_Attrib* parse_shader_attribs(const char* fname) {
	FILE* f;
	char magic[4];
	int type, i;
	int binding_count, target_count, desc_count, opt_count;
	char* fpath = arena_alloc(
		&arena,
		string_len(fname) +
		string_len(shader_dir) +
		string_len(shader_dir_sep) + 1
	);
	strcpy(fpath, shader_dir);
	strcat(fpath, shader_dir_sep);
	strcat(fpath, fname);
	f = fopen(fpath, "rb");
	if (!f) {
		print_err("Failed to open shader %s.\n", fpath);
		pbreak(303);
		return 0;
	}
	if (
		!shader_read(f, magic, sizeof magic) ||
		magic[0] != 'C' ||
		magic[1] != 'S' ||
		magic[2] != 'H' ||
		magic[3] != '2'
	)
		goto invalid;
	/* Every header field is checked so a truncated file cannot leave the
	 * counts uninitialised. */
	if (
		!shader_read(f, &type, 4) ||
		!shader_read(f, &binding_count, 4) ||
		!shader_read(f, &target_count, 4) ||
		!shader_read(f, &desc_count, 4) ||
		!shader_read(f, &opt_count, 4)
	)
		goto invalid;
	assert(type == sprogram_type_graphics);
	assert(binding_count);
	for (i = 0; i < binding_count; i++) {
		char name[24];
		int rate, count;
		if (
			!shader_read(f, name, sizeof name) ||
			!shader_read(f, &rate, 4) ||
			!shader_read(f, &count, 4) ||
			count < 0
		)
			goto invalid;
		/* The field is fixed-size; make sure string helpers stay in bounds. */
		name[sizeof name - 1] = 0;
		if (string_equal(name, "mesh")) {
			int j, coff = 0;
			Shader_Attrib* r = 0;
			for (j = 0; j < count; j++) {
				char aname[28];
				SVariable_Type vtype;
				cgltf_attribute_type target;
				Shader_Attrib* attrib;
				if (
					!shader_read(f, aname, sizeof aname) ||
					!shader_read(f, &vtype, 4)
				)
					goto invalid;
				aname[sizeof aname - 1] = 0;
				target = get_attribute_type(aname);
				if (target == cgltf_attribute_type_invalid) {
					print_err("%s is not a valid mesh attribute.\n", aname);
					pbreak(305);
					fclose(f);
					return 0;
				}
				attrib = arena_alloc(&arena, sizeof *attrib);
				attrib->target = target;
				attrib->type = vtype;
				attrib->offset = coff;
				attrib->size = svariable_type_size(vtype);
				attrib->next = r;
				r = attrib;
				coff += attrib->size;
			}
			fclose(f);
			return r;
		} else {
			/* Skip this binding's attribute records (32 bytes each). */
			if (fseek(f, 32 * count, SEEK_CUR) != 0)
				goto invalid;
		}
	}
	print_err("Shader %s has no mesh vertex binding.\n", fname);
	pbreak(306);
	fclose(f);
	return 0;

invalid:
	print_err("%s - invalid shader.\n", fname);
	pbreak(304);
	fclose(f);
	return 0;
}

/*
 * Returns the size in bytes of one vertex: the sum of the sizes of all
 * attributes in the list. An empty list gives 0.
 */
int calc_vertex_size(Shader_Attrib* attribs) {
	const Shader_Attrib* cur = attribs;
	int total = 0;
	while (cur != 0) {
		total += cur->size;
		cur = cur->next;
	}
	return total;
}

/*
 * Processes one glTF node: reads its export settings from the "extras"
 * JSON and, if the node carries a mesh, converts and writes that mesh.
 * Nodes without a mesh produce no output.
 *
 * parent_index - index of the parent node, written into the mesh record.
 * n            - node to process.
 * outfile      - destination stream.
 *
 * The shader names default to "surface.csh" and "surface_depthonly.csh";
 * a mesh node without a material is reported with print_err()/pbreak(49).
 */
void parse_node(
	int parent_index,
	const cgltf_node* n,
	FILE* outfile
) {
	Node_Config cfg = { 0 };
	Shader_Attrib* desired;
	if (n->extras.data)
		parse_node_cfg(n, &cfg);
	if (!n->mesh)
		return;
	if (!cfg.shader)
		cfg.shader = "surface.csh";
	if (!cfg.depth_shader)
		cfg.depth_shader = "surface_depthonly.csh";
	if (!cfg.material) {
		/* Node names are optional in glTF; never hand a null to %s. */
		print_err(
			"Node %s has a mesh, but doesn't specify a material.\n",
			n->name ? n->name : "(unnamed)"
		);
		pbreak(49);
		/* Only reached if pbreak returns: a null material cannot be written. */
		return;
	}
	desired = parse_shader_attribs(cfg.shader);
	/* Validate the layout before anything is derived from it. */
	assert(desired != 0);
	vertex_size = calc_vertex_size(desired);
	current_vertex = arena_alloc(&arena, vertex_size);
	parse_node_mesh(parent_index, n, &cfg, desired, outfile);
}

/*
 * Writes the file header: the tag "MODL" followed by two 4-byte zero
 * fields.
 */
void write_header(FILE* outfile) {
	const int placeholder = 0;
	int field;
	fwrite("MODL", 4, 1, outfile);
	for (field = 0; field < 2; field++)
		fwrite(&placeholder, 4, 1, outfile);
}

/*
 * Loads the glTF file `fname` and writes its node count followed by one
 * record per mesh-carrying node to `outfile`.
 *
 * The arena is cleared before each node, so nothing allocated for one node
 * survives into the next.
 *
 * Returns 0 on success, or the non-zero cgltf_result when the file or its
 * buffers cannot be loaded. The parsed data is released on every path.
 */
int parse(const char* fname, FILE* outfile) {
	int i, c;
	cgltf_options opt = { 0 };
	cgltf_data* d = 0;
	cgltf_result r;
	r = cgltf_parse_file(&opt, fname, &d);
	if (r) return r;
	r = cgltf_load_buffers(&opt, d, fname);
	if (r) {
		/* The document was parsed successfully, so it must be freed here. */
		cgltf_free(d);
		return r;
	}
	c = (int)d->nodes_count;
	fwrite(&c, 4, 1, outfile);
	for (i = 0; i < c; i++) {
		const cgltf_node* n = &d->nodes[i];
		/*
		 * NOTE(review): a node without a parent is written with parent
		 * index 0, which is indistinguishable from a child of node 0 -
		 * confirm this is what the reader of the format expects.
		 */
		int parent_index = 0;
		if (n->parent)
			parent_index = (int)(n->parent - d->nodes);
		clear_arena(&arena);
		parse_node(parent_index, n, outfile);
	}
	cgltf_free(d);
	return 0;
}

/*
 * Entry point: converts a glTF file into the model format.
 *
 * Usage: <program> shader_dir infile outfile
 *
 * Writes the header, converts every node, then seeks back to offset 4 and
 * stores the total vertex and index byte counts in the header.
 *
 * Returns 0 on success, the cgltf error code when parsing fails, and
 * EXIT_FAILURE for usage, allocation or output-file errors.
 */
int main(int argc, const char** argv) {
	int r = EXIT_FAILURE;
	int mem_size = 1024 * 1024;
	int dir_len;
	void* mem = 0;
	FILE* outfile;
	if (argc < 4) {
		print_err("Usage: %s shader_dir infile outfile\n", argv[0]);
		return EXIT_FAILURE;
	}
	vertex_buffer_size = 1024;
	index_buffer_size = 1024;
	vertex_buffer = malloc(vertex_buffer_size);
	index_buffer = malloc(index_buffer_size);
	mem = malloc(mem_size);
	if (!vertex_buffer || !index_buffer || !mem) {
		print_err("Out of memory.\n");
		goto cleanup;
	}
	shader_dir = argv[1];
	shader_dir_sep = "";
	/* An empty directory means "current directory": no separator, and no
	 * read before the start of the string. */
	dir_len = string_len(shader_dir);
	if (dir_len > 0 && shader_dir[dir_len - 1] != '/')
		shader_dir_sep = "/";
	init_arena(&arena, mem, mem_size);
	outfile = fopen(argv[3], "wb");
	if (!outfile) {
		print_err("Failed to open output file %s.\n", argv[3]);
		goto cleanup;
	}
	write_header(outfile);
	r = parse(argv[2], outfile);
	if (r) {
		print_err("Parse or file error.\n");
		fclose(outfile);
		goto cleanup;
	}
	r = EXIT_FAILURE;
	/* Patch the two placeholder fields that follow the 4-byte file tag. */
	if (
		fseek(outfile, 4, SEEK_SET) != 0 ||
		fwrite(&full_vert_size, 4, 1, outfile) != 1 ||
		fwrite(&full_ind_size, 4, 1, outfile) != 1
	) {
		print_err("Failed to write output file %s.\n", argv[3]);
		fclose(outfile);
		goto cleanup;
	}
	/* fclose flushes buffered data, so a failure here is a write failure. */
	if (fclose(outfile) != 0) {
		print_err("Failed to write output file %s.\n", argv[3]);
		goto cleanup;
	}
	r = 0;

cleanup:
	free(mem);
	free(index_buffer);
	free(vertex_buffer);
	return r;
}