#include "maths.hpp"

#include <algorithm>

/**
 * Grows this box in place so that it also contains `other`.
 *
 * Each axis is widened independently: the lower corner keeps the smaller
 * coordinate and the upper corner keeps the larger one.
 */
void AABB::encompass(const AABB& other) {
	const auto widen = [](float& lo, float& hi, float other_lo, float other_hi) {
		lo = std::min(lo, other_lo);
		hi = std::max(hi, other_hi);
	};

	widen(min.x, max.x, other.min.x, other.max.x);
	widen(min.y, max.y, other.min.y, other.max.y);
	widen(min.z, max.z, other.min.z, other.max.z);
}

/*
 * Quaternion helpers. A quaternion is stored in a v4f as (x, y, z, w):
 * the vector part in x/y/z and the scalar part in w. identity(), conjugate()
 * and mul() below all rely on that layout.
 */
namespace quat {
	/// The identity rotation (0, 0, 0, 1).
	v4f identity() {
		return v4f(0.0f, 0.0f, 0.0f, 1.0f);
	}

	/// Multiplies only the vector part (x, y, z) by `s`; w is left unchanged.
	/// NOTE(review): this is not a uniform scale of all four components —
	/// confirm callers rely on this behaviour.
	v4f scale(const v4f& q, float s) {
		return v4f(q.xyz() * s, q.w);
	}

	/// Returns `q` scaled to unit length.
	///
	/// Every component, including the scalar w, is divided by the magnitude
	/// (presumably v4f::mag is the 4-component Euclidean length). Dividing
	/// only x/y/z would not produce a unit quaternion.
	/// A zero-length `q` divides by zero and yields non-finite components.
	v4f normalised(const v4f& q) {
		const float l = v4f::mag(q);
		return v4f(q.xyz() / l, q.w / l);
	}

	/// Negates the vector part. For a unit quaternion this is its inverse.
	v4f conjugate(const v4f& q) {
		return v4f(-q.xyz(), q.w);
	}

	/// Hamilton product a * b, without normalisation.
	/// Term order is kept as-is so floating-point results are unchanged.
	template <>
	v4f mul<false>(const v4f& a, const v4f& b) {
		return v4f(
			 a.x * b.w + a.y * b.z - a.z * b.y + a.w * b.x,
			-a.x * b.z + a.y * b.w + a.z * b.x + a.w * b.y,
			 a.x * b.y - a.y * b.x + a.z * b.w + a.w * b.z,
			-a.x * b.x - a.y * b.y - a.z * b.z + a.w * b.w
		);
	}

	/// Hamilton product a * b, renormalised to unit length.
	template <>
	v4f mul<true>(const v4f& a, const v4f& b) {
		return normalised(mul<false>(a, b));
	}

	/// Default product: the normalised variant.
	v4f mul(const v4f& a, const v4f& b) {
		return mul<true>(a, b);
	}

	/// Quaternion for a rotation of `angle` about `axis`.
	///
	/// `angle` is converted with to_rad (presumably from degrees — confirm).
	/// `axis` is used as given, so it should already be unit length.
	/// The scalar part cos(angle/2) goes in w, matching identity() and mul().
	v4f rotate(float angle, const v3f& axis) {
		const float h = to_rad(angle) * 0.5f;
		const float s = sinf(h);
		return v4f(
			s * axis.x,
			s * axis.y,
			s * axis.z,
			cosf(h)
		);
	}

	/// Builds a rotation from per-axis angles `a` (same units as rotate()).
	/// The result is mul(mul(Rx, Ry), Rz), renormalised after each product.
	v4f euler(const v3f& a) {
		return mul(mul(
			rotate(a.x, v3f(1.0f, 0.0f, 0.0f)),
			rotate(a.y, v3f(0.0f, 1.0f, 0.0f))),
			rotate(a.z, v3f(0.0f, 0.0f, 1.0f))
		);
	}
}

// Default constructor: leaves every element uninitialised. Use m4f(float) or
// identity() when a defined value is needed.
m4f::m4f() = default;

/**
 * Diagonal matrix: every diagonal element is `d`, every other element is 0.
 * m4f(1.0f) is therefore the identity.
 */
m4f::m4f(float d) {
	for (int col = 0; col < 4; col++) {
		for (int row = 0; row < 4; row++) {
			m[col][row] = (col == row) ? d : 0.0f;
		}
	}
}

/**
 * Builds a matrix from four vectors: argument i fills m[i][0..3] with its
 * x, y, z, w.
 *
 * operator*(const v4f&) multiplies input component i by m[i][*], so each
 * argument behaves as a column of the matrix despite the r0..r3 names.
 */
m4f::m4f(const v4f& r0, const v4f& r1, const v4f& r2, const v4f& r3) {
	const v4f* const src[4] = { &r0, &r1, &r2, &r3 };

	for (int i = 0; i < 4; i++) {
		m[i][0] = src[i]->x;
		m[i][1] = src[i]->y;
		m[i][2] = src[i]->z;
		m[i][3] = src[i]->w;
	}
}

/** The 4x4 identity matrix (ones on the diagonal, zeros elsewhere). */
m4f m4f::identity() {
	m4f result(1.0f);
	return result;
}

/**
 * Builds a matrix that maps x in [-1, 1] to [0, 2*hw] and y in [-1, 1] to
 * [2*hh, 0] (y flipped). Points are taken with w = 1:
 *
 *     x' = hw * x + hw
 *     y' = hh - hh * y
 *     z' = z,  w' = w
 *
 * Storage follows this file's m[column][row] convention: operator*(const v4f&)
 * reads m[col][row], and translate() stores offsets in m[3][0..2]. The
 * offsets therefore belong in column 3. Writing them to m[0][3] / m[1][3]
 * (row 3) would feed x and y into the output w instead of translating.
 *
 * NOTE(review): if a consumer previously compensated for the old layout
 * (e.g. by transposing or using a row-vector shader multiply), update it too.
 */
m4f m4f::screenspace(float hw, float hh) {
	m4f r(1.0f);

	r.m[0][0] =  hw;
	r.m[3][0] =  hw;
	r.m[1][1] = -hh;
	r.m[3][1] =  hh;

	return r;
}

/**
 * Matrix product (*this) * other, with both operands stored m[column][row]:
 *
 *     result[col][row] = sum over k of this[k][row] * other[col][k]
 *
 * The four products are accumulated left to right, k = 0..3.
 */
m4f m4f::operator*(const m4f& other) const {
	m4f product;

	for (int col = 0; col < 4; col++) {
		for (int row = 0; row < 4; row++) {
			float acc = m[0][row] * other.m[col][0];
			acc += m[1][row] * other.m[col][1];
			acc += m[2][row] * other.m[col][2];
			acc += m[3][row] * other.m[col][3];
			product.m[col][row] = acc;
		}
	}

	return product;
}

/**
 * Matrix-vector product M * v with m[column][row] storage:
 * out[row] = sum over col of m[col][row] * v[col].
 */
v4f m4f::operator*(const v4f& other) const {
	const float vx = other.x;
	const float vy = other.y;
	const float vz = other.z;
	const float vw = other.w;

	const float out_x = m[0][0] * vx + m[1][0] * vy + m[2][0] * vz + m[3][0] * vw;
	const float out_y = m[0][1] * vx + m[1][1] * vy + m[2][1] * vz + m[3][1] * vw;
	const float out_z = m[0][2] * vx + m[1][2] * vy + m[2][2] * vz + m[3][2] * vw;
	const float out_w = m[0][3] * vx + m[1][3] * vy + m[2][3] * vz + m[3][3] * vw;

	return v4f(out_x, out_y, out_z, out_w);
}

/**
 * Returns m * T, where T translates by `v`.
 *
 * Because T is on the right, it is applied to a point before `m`:
 * (m * T) * p == m * (T * p).
 */
m4f m4f::translate(m4f m, v3f v) {
	m4f offset(1.0f);

	// Column 3 (m[3][0..2]) carries the translation for operator*(const v4f&).
	offset.m[3][0] += v.x;
	offset.m[3][1] += v.y;
	offset.m[3][2] += v.z;

	return m * offset;
}

/**
 * Returns m * R, where R rotates by `a` about the axis `v`.
 *
 * `a` is passed straight to cosf/sinf, so it is in radians (unlike
 * quat::rotate, which goes through to_rad). `v` is used as given; it is not
 * normalised, so it must be unit length for R to be a pure rotation.
 */
m4f m4f::rotate(m4f m, float a, v3f v) {
	const float cos_a = cosf(a);
	const float sin_a = sinf(a);
	const float k = 1.0f - cos_a;

	const float ax = v.x;
	const float ay = v.y;
	const float az = v.z;

	m4f rot(1.0f);

	// Column 0
	rot.m[0][0] = ax * ax * k + cos_a;
	rot.m[0][1] = ay * ax * k + az * sin_a;
	rot.m[0][2] = ax * az * k - ay * sin_a;

	// Column 1
	rot.m[1][0] = ax * ay * k - az * sin_a;
	rot.m[1][1] = ay * ay * k + cos_a;
	rot.m[1][2] = ay * az * k + ax * sin_a;

	// Column 2
	rot.m[2][0] = ax * az * k + ay * sin_a;
	rot.m[2][1] = ay * az * k - ax * sin_a;
	rot.m[2][2] = az * az * k + cos_a;

	return m * rot;
}

/**
 * Writes the rotation described by quaternion q = (x, y, z, w), scalar in w,
 * into `r` and returns it.
 *
 * Only m[0..2][0..2] are written, and m[0..2][3] are set to 0. m[3][0..3] is
 * taken from `r` unchanged; that is the translation column under
 * operator*(const v4f&), plus m[3][3]. Unlike rotate(m, a, v), this does NOT
 * compose with `r`: any existing rotation or scale in r's upper 3x3 is
 * replaced.
 *
 * Negating x/y/z into qx/qy/qz is compensated by the sign choices below. The
 * result is the usual rotation matrix for q in m[column][row] layout, e.g.
 * m[0][1] = 2(xy + zw) and m[1][0] = 2(xy - zw). This form is only a pure
 * rotation when q has unit length; q is not normalised here.
 */
m4f m4f::rotate(m4f r, const v4f& q) {
	float qx, qy, qz, qw, qx2, qy2, qz2, qxqx2, qyqy2, qzqz2, qxqy2, qyqz2, qzqw2, qxqz2, qyqw2, qxqw2;
	qx = -q.x;
	qy = -q.y;
	qz = -q.z;
	qw = q.w;
	// Doubled components: qa2 == 2 * qa.
	qx2 = (qx + qx);
	qy2 = (qy + qy);
	qz2 = (qz + qz);
	// Pairwise products, already doubled (e.g. qxqy2 == 2 * qx * qy).
	qxqx2 = (qx * qx2);
	qxqy2 = (qx * qy2);
	qxqz2 = (qx * qz2);
	qxqw2 = (qw * qx2);
	qyqy2 = (qy * qy2);
	qyqz2 = (qy * qz2);
	qyqw2 = (qw * qy2);
	qzqz2 = (qz * qz2);
	qzqw2 = (qw * qz2);
	// Column 0
	r.m[0][0] = ((1.0f - qyqy2) - qzqz2);
	r.m[0][1] = qxqy2 - qzqw2;
	r.m[0][2] = qxqz2 + qyqw2;
	r.m[0][3] = 0.0f;
	// Column 1
	r.m[1][0] = qxqy2 + qzqw2;
	r.m[1][1] = (1.0f - qxqx2) - qzqz2;
	r.m[1][2] = qyqz2 - qxqw2;
	r.m[1][3] = 0.0f;
	// Column 2
	r.m[2][0] = qxqz2 - qyqw2;
	r.m[2][1] = qyqz2 + qxqw2;
	r.m[2][2] = (1.0f - qxqx2) - qyqy2;
	r.m[2][3] = 0.0f;
	// Column 3 (r.m[3][*]) is intentionally left as passed in.
	return r;
}

/** Returns m * S, where S scales the x, y and z axes by v.x, v.y and v.z. */
m4f m4f::scale(m4f m, v3f v) {
	m4f scaling(1.0f);

	scaling.m[2][2] = v.z;
	scaling.m[1][1] = v.y;
	scaling.m[0][0] = v.x;

	return m * scaling;
}


/** Reads the translation stored in column 3, i.e. m[3][0..2]. */
v3f m4f::get_translation() {
	const float* const column3 = m[3];
	return v3f(column3[0], column3[1], column3[2]);
}

/**
 * View matrix for an eye at `c` looking toward `o`, with `u` as the
 * approximate up direction.
 *
 * Row 0 of the rotation part is the camera's right vector, row 1 its
 * re-orthogonalised up, and row 2 the negated forward direction. The viewing
 * direction therefore maps to -Z. The translation moves `c` to the origin.
 */
m4f m4f::lookat(v3f c, v3f o, v3f u) {
	const v3f forward = v3f::normalised(o - c);
	const v3f up_hint = v3f::normalised(u);
	const v3f right = v3f::normalised(v3f::cross(forward, up_hint));
	const v3f up = v3f::cross(right, forward);

	m4f view(1.0f);

	view.m[0][0] = right.x;
	view.m[1][0] = right.y;
	view.m[2][0] = right.z;

	view.m[0][1] = up.x;
	view.m[1][1] = up.y;
	view.m[2][1] = up.z;

	view.m[0][2] = -forward.x;
	view.m[1][2] = -forward.y;
	view.m[2][2] = -forward.z;

	view.m[3][0] = -v3f::dot(right, c);
	view.m[3][1] = -v3f::dot(up, c);
	view.m[3][2] =  v3f::dot(forward, c);

	return view;
}

/**
 * Perspective projection.
 *
 * @param fov  vertical field of view, converted with to_rad (presumably degrees)
 * @param asp  aspect ratio; the x scale is divided by it
 * @param n    near plane distance
 * @param f    far plane distance
 *
 * After the perspective divide, view-space z = -n maps to depth 0 and z = -f
 * maps to depth 1. The y scale is negated, so clip-space y points the
 * opposite way to view-space y.
 * NOTE(review): presumably this targets a y-down, [0, 1]-depth clip space
 * (Vulkan-style) — confirm.
 */
m4f m4f::pers(float fov, float asp, float n, float f) {
	const float fov_rad = to_rad(fov);
	const float g = 1.0f / tanf(fov_rad / 2.0f);
	const float depth_a = f / (n - f);

	m4f proj(0.0f);

	proj.m[0][0] = g / asp;
	proj.m[1][1] = -g;
	proj.m[2][2] = depth_a;
	proj.m[3][2] = n * depth_a;
	proj.m[2][3] = -1.0f;   // w_clip = -z_view

	return proj;
}

/**
 * Orthographic projection for the box [l, r] x [b, t] x [n, f].
 *
 * Mappings:
 * - x: l -> -1 and r -> +1.
 * - y: b -> +1 and t -> -1, because the denominators use (b - t), which
 *   flips y.
 * - depth: view-space z = -n -> 0 and z = -f -> 1.
 */
m4f m4f::orth(float l, float r, float b, float t, float n, float f) {
	const float width = r - l;
	const float height = b - t;
	const float depth = n - f;

	m4f out(1.0f);

	out.m[0][0] = 2.0f / width;
	out.m[1][1] = 2.0f / height;
	out.m[2][2] = 1.0f / depth;

	out.m[3][0] = -(l + r) / width;
	out.m[3][1] = -(b + t) / height;
	out.m[3][2] = n / depth;

	return out;
}

/**
 * General 4x4 inverse by cofactor expansion (adjugate / determinant).
 *
 * The matrix is read through a flat pointer, so mm[4 * i + j] == m[i][j].
 * This assumes m is a contiguous float[4][4] — TODO confirm against the
 * header.
 *
 * How it works:
 * - t0..t23 are 2x2 products shared by several cofactor expressions.
 * - o[0..3] are computed unscaled first and used to form the determinant.
 * - Every output is then multiplied by d = 1 / determinant.
 *
 * There is no singularity check: a zero determinant makes d infinite and the
 * result non-finite. The `(float*)m` cast drops const, but m is only read.
 */
m4f m4f::inverse() const {
	const float* mm = (float*)m;
	m4f r;
	// Shared 2x2 products.
	float t0 = mm[10] * mm[15];
	float t1 = mm[14] * mm[11];
	float t2 = mm[6] * mm[15];
	float t3 = mm[14] * mm[7];
	float t4 = mm[6] * mm[11];
	float t5 = mm[10] * mm[7];
	float t6 = mm[2] * mm[15];
	float t7 = mm[14] * mm[3];
	float t8 = mm[2] * mm[11];
	float t9 = mm[10] * mm[3];
	float t10 = mm[2] * mm[7];
	float t11 = mm[6] * mm[3];
	float t12 = mm[8] * mm[13];
	float t13 = mm[12] * mm[9];
	float t14 = mm[4] * mm[13];
	float t15 = mm[12] * mm[5];
	float t16 = mm[4] * mm[9];
	float t17 = mm[8] * mm[5];
	float t18 = mm[0] * mm[13];
	float t19 = mm[12] * mm[1];
	float t20 = mm[0] * mm[9];
	float t21 = mm[8] * mm[1];
	float t22 = mm[0] * mm[5];
	float t23 = mm[4] * mm[1];
	float* o = (float*)r.m;
	// First four adjugate entries, still unscaled; also used for the determinant.
	o[0] = (t0 * mm[5] + t3 * mm[9] + t4 * mm[13]) - (t1 * mm[5] + t2 * mm[9] + t5 * mm[13]);
	o[1] = (t1 * mm[1] + t6 * mm[9] + t9 * mm[13]) - (t0 * mm[1] + t7 * mm[9] + t8 * mm[13]);
	o[2] = (t2 * mm[1] + t7 * mm[5] + t10 * mm[13]) - (t3 * mm[1] + t6 * mm[5] + t11 * mm[13]);
	o[3] = (t5 * mm[1] + t8 * mm[5] + t11 * mm[9]) - (t4 * mm[1] + t9 * mm[5] + t10 * mm[9]);
	// d = 1 / determinant, expanding mm[0], mm[4], mm[8], mm[12] against o[0..3].
	float d = 1.0f / (mm[0] * o[0] + mm[4] * o[1] + mm[8] * o[2] + mm[12] * o[3]);
	o[0] = d * o[0];
	o[1] = d * o[1];
	o[2] = d * o[2];
	o[3] = d * o[3];
	// Remaining twelve entries, scaled by d as they are written.
	o[4] = d * ((t1 * mm[4] + t2 * mm[8] + t5 * mm[12]) - (t0 * mm[4] + t3 * mm[8] + t4 * mm[12]));
	o[5] = d * ((t0 * mm[0] + t7 * mm[8] + t8 * mm[12]) - (t1 * mm[0] + t6 * mm[8] + t9 * mm[12]));
	o[6] = d * ((t3 * mm[0] + t6 * mm[4] + t11 * mm[12]) - (t2 * mm[0] + t7 * mm[4] + t10 * mm[12]));
	o[7] = d * ((t4 * mm[0] + t9 * mm[4] + t10 * mm[8]) - (t5 * mm[0] + t8 * mm[4] + t11 * mm[8]));
	o[8] = d * ((t12 * mm[7] + t15 * mm[11] + t16 * mm[15]) - (t13 * mm[7] + t14 * mm[11] + t17 * mm[15]));
	o[9] = d * ((t13 * mm[3] + t18 * mm[11] + t21 * mm[15]) - (t12 * mm[3] + t19 * mm[11] + t20 * mm[15]));
	o[10] = d * ((t14 * mm[3] + t19 * mm[7] + t22 * mm[15]) - (t15 * mm[3] + t18 * mm[7] + t23 * mm[15]));
	o[11] = d * ((t17 * mm[3] + t20 * mm[7] + t23 * mm[11]) - (t16 * mm[3] + t21 * mm[7] + t22 * mm[11]));
	o[12] = d * ((t14 * mm[10] + t17 * mm[14] + t13 * mm[6]) - (t16 * mm[14] + t12 * mm[6] + t15 * mm[10]));
	o[13] = d * ((t20 * mm[14] + t12 * mm[2] + t19 * mm[10]) - (t18 * mm[10] + t21 * mm[14] + t13 * mm[2]));
	o[14] = d * ((t18 * mm[6] + t23 * mm[14] + t15 * mm[2]) - (t22 * mm[14] + t14 * mm[2] + t19 * mm[6]));
	o[15] = d * ((t22 * mm[10] + t16 * mm[2] + t21 * mm[6]) - (t20 * mm[6] + t23 * mm[10] + t17 * mm[2]));
	return r;
}

/**
 * Returns the transpose: result.m[i][j] == m[j][i] for all i, j.
 *
 * The previous hand-unrolled version wrote m[3][3] into r.m[2][3] (which
 * overwrote the correct m[3][2] value) and left r.m[3][3] at 1.0. The loop
 * below covers every element, including the diagonal.
 */
m4f m4f::transposed() const {
	m4f r;
	for (int i = 0; i < 4; i++) {
		for (int j = 0; j < 4; j++) {
			r.m[i][j] = m[j][i];
		}
	}
	return r;
}

/** Applies `m` to `v`; equivalent to `m * v`. */
v4f m4f::transform(const m4f& m, const v4f& v) {
	v4f transformed = m * v;
	return transformed;
}

/**
 * Transforms all eight corners of `aabb` by `m` (as points, w = 1) and
 * returns the axis-aligned box that encloses them.
 *
 * No perspective divide is performed: the x/y/z of m * p are used directly.
 */
AABB m4f::transform(const m4f& m, const AABB& aabb) {
	const v3f& lo = aabb.min;
	const v3f& hi = aabb.max;

	const v3f corners[8] = {
		lo,
		v3f(lo.x, hi.y, lo.z),
		v3f(lo.x, hi.y, hi.z),
		v3f(lo.x, lo.y, hi.z),
		v3f(hi.x, lo.y, lo.z),
		v3f(hi.x, hi.y, lo.z),
		hi,
		v3f(hi.x, lo.y, hi.z)
	};

	// Start inverted so the first corner always replaces both bounds.
	AABB result = {
		.min = { INFINITY, INFINITY, INFINITY },
		.max = { -INFINITY, -INFINITY, -INFINITY }
	};

	for (const v3f& corner : corners) {
		const v4f p = m4f::transform(m, v4f(corner.x, corner.y, corner.z, 1.0f));

		result.min.x = std::min(result.min.x, p.x);
		result.min.y = std::min(result.min.y, p.y);
		result.min.z = std::min(result.min.z, p.z);
		result.max.x = std::max(result.max.x, p.x);
		result.max.y = std::max(result.max.y, p.y);
		result.max.z = std::max(result.max.z, p.z);
	}

	return result;
}

// Default constructor: leaves every element uninitialised. Use m3f(float) when
// a defined value is needed.
m3f::m3f() = default;
/**
 * Diagonal 3x3 matrix: m[0][0], m[1][1] and m[2][2] are `d`, everything else
 * is 0. m3f(1.0f) is the identity.
 *
 * Only indices 0..2 are valid here. The former `m[3][3] = d` (copied from the
 * 4x4 constructor) wrote past the end of the 3x3 storage, which is undefined
 * behaviour.
 */
m3f::m3f(float d) {
	for (int y = 0; y < 3; y++) {
		for (int x = 0; x < 3; x++) {
			m[x][y] = 0.0f;
		}
	}
	m[0][0] = d;
	m[1][1] = d;
	m[2][2] = d;
}

/**
 * Builds a 3x3 matrix from three vectors: argument i fills m[i][0..2] with its
 * x, y, z.
 *
 * operator*(const v3f&) multiplies input component i by m[i][*], so each
 * argument behaves as a column despite the r0..r2 names.
 */
m3f::m3f(const v3f& r0, const v3f& r1, const v3f& r2) {
	const v3f* const src[3] = { &r0, &r1, &r2 };

	for (int i = 0; i < 3; i++) {
		m[i][0] = src[i]->x;
		m[i][1] = src[i]->y;
		m[i][2] = src[i]->z;
	}
}

/** Applies `m` to `v`; equivalent to `m * v`. */
v3f m3f::transform(const m3f& m, const v3f& v) {
	v3f transformed = m * v;
	return transformed;
}

/**
 * Matrix-vector product M * v with m[column][row] storage:
 * out[row] = sum over col of m[col][row] * v[col].
 */
v3f m3f::operator*(const v3f& other) const {
	const float vx = other.x;
	const float vy = other.y;
	const float vz = other.z;

	const float out_x = m[0][0] * vx + m[1][0] * vy + m[2][0] * vz;
	const float out_y = m[0][1] * vx + m[1][1] * vy + m[2][1] * vz;
	const float out_z = m[0][2] * vx + m[1][2] * vy + m[2][2] * vz;

	return v3f(out_x, out_y, out_z);
}

/**
 * 3x3 inverse computed as adjugate / determinant.
 *
 * The adjugate is the transposed cofactor matrix: adj[i][j] is the cofactor
 * of element (j, i). The determinant is expanded along m[0][*] against the
 * matching cofactors.
 *
 * There is no singularity check: a zero determinant yields non-finite entries.
 */
m3f m3f::inverse() const {
	const float a00 = m[1][1] * m[2][2] - m[1][2] * m[2][1];
	const float a01 = m[0][2] * m[2][1] - m[0][1] * m[2][2];
	const float a02 = m[0][1] * m[1][2] - m[0][2] * m[1][1];
	const float a10 = m[1][2] * m[2][0] - m[1][0] * m[2][2];
	const float a11 = m[0][0] * m[2][2] - m[0][2] * m[2][0];
	const float a12 = m[0][2] * m[1][0] - m[0][0] * m[1][2];
	const float a20 = m[1][0] * m[2][1] - m[1][1] * m[2][0];
	const float a21 = m[0][1] * m[2][0] - m[0][0] * m[2][1];
	const float a22 = m[0][0] * m[1][1] - m[0][1] * m[1][0];

	const float det = m[0][0] * a00 + m[0][1] * a10 + m[0][2] * a20;
	const float inv_det = 1.0f / det;

	// Each argument fills m[i][0..2] of the result.
	return m3f(
		v3f(a00 * inv_det, a01 * inv_det, a02 * inv_det),
		v3f(a10 * inv_det, a11 * inv_det, a12 * inv_det),
		v3f(a20 * inv_det, a21 * inv_det, a22 * inv_det)
	);
}