﻿using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEditor;
using System;

namespace BS
{
    public class NodalPathfinding2D : MonoBehaviour
    {
        // World-space box the node lattice is laid out in; grid index (0, 0) maps to bounds.min.
        public Bounds bounds;
        // Spacing between neighbouring grid indices, in world units.
        [Range(0.1f, 1f)]
        public float pointInterval = 0.5f;
		// Radius of the overlap circle used to decide whether a grid point is blocked.
		public float agentRadius = 0.5f;
        
        // Layers passed to every Physics2D overlap/linecast/raycast query made by this component.
        public LayerMask blockMask;
        // Baked graph, keyed by grid index.
        private Dictionary<Vector2Int, Node> nodes = new Dictionary<Vector2Int, Node>();

        // True once at least one node has been baked.
        public bool isBaked { get { return nodes.Count > 0; } }

		// When true, BakeNodes prunes fully surrounded nodes and FindNode searches a wider window.
		public bool isOptimizing = true;
		// Grid stride at which pruned nodes are re-inserted; BakeNodes rounds it down to an even value.
		[Range(2,10)]
		public int complementLevel = 2;

		// OnDrawGizmos draws nothing while this is false.
		public bool enableGizmos = true;

#if UNITY_EDITOR

		/// <summary>
		/// Editor-only visualisation. When the graph is baked it draws every node and its
		/// adjacency lines; otherwise it draws the bounds and a preview of the sample lattice
		/// coloured by whether each point is blocked. Finally it marks the grid point nearest
		/// to the mouse cursor, when a GUI event is available.
		/// </summary>
		private void OnDrawGizmos()
		{
			if (!enableGizmos)
				return;

			Gizmos.color = Color.red;

			if (isBaked)
			{
				DrawBakedGraph();
			}
			else
			{
				Gizmos.DrawWireCube(bounds.center, bounds.size);

				// The lattice only uses points whose index sum is even, i.e. two interleaved
				// square grids with a stride of two intervals, offset by one interval diagonally.
				DrawSampleLattice(bounds.min.x, bounds.min.y);
				DrawSampleLattice(bounds.min.x + pointInterval, bounds.min.y + pointInterval);
			}

			// Event.current is not guaranteed to be set for every gizmo pass; skip the
			// cursor marker instead of throwing a NullReferenceException.
			Event guiEvent = Event.current;
			if (guiEvent == null)
				return;

			Vector3 mousePosition = HandleUtility.GUIPointToWorldRay(guiEvent.mousePosition).origin;
			mousePosition.z = 0;

			Gizmos.color = Color.red;
			Gizmos.DrawSphere(IndexToWorld(WorldToIndex(mousePosition, Vector2.right)), pointInterval / 5);
		}

		// Draws every baked node as a sphere plus one line per adjacency.
		private void DrawBakedGraph()
		{
			// The node at this index is drawn in yellow; it is currently fixed to the grid origin.
			Vector2Int highlighted = Vector2Int.zero;
			float radius = pointInterval / 5;

			foreach (Node node in nodes.Values)
			{
				bool isHighlighted = node.gridPosition == highlighted;

				Gizmos.color = isHighlighted ? Color.yellow : Color.blue;
				Gizmos.DrawSphere(node.worldPositon, radius);

				if (node.adjacencies == null)
					continue;

				if (!isHighlighted)
					Gizmos.color = new Color(0, 1, 0, .1f);

				foreach (Vector2Int adj in node.adjacencies)
				{
					Node neighbour;
					if (nodes.TryGetValue(adj, out neighbour))
						Gizmos.DrawLine(node.worldPositon, neighbour.worldPositon);
				}
			}
		}

		// Draws one of the two interleaved preview grids: red where the agent circle overlaps
		// something on blockMask, green where it is free.
		private void DrawSampleLattice(float startX, float startY)
		{
			float stride = pointInterval * 2;
			if (stride <= 0f)
				return; // a non-positive stride would never reach bounds.max and hang the editor

			float radius = pointInterval / 5;
			for (float x = startX; x <= bounds.max.x; x += stride)
			{
				for (float y = startY; y <= bounds.max.y; y += stride)
				{
					if (Physics2D.OverlapCircle(new Vector2(x, y), agentRadius, blockMask))
						Gizmos.color = Color.red;
					else
						Gizmos.color = Color.green;
					Gizmos.DrawSphere(new Vector3(x, y, 0), radius);
				}
			}
		}

#endif
		/// <summary>
		/// Unity start-up hook: bakes the graph once if nothing has been baked yet.
		/// </summary>
		private void Start()
		{
			if (isBaked)
			{
				return;
			}

			BakeNodes();
		}

		/// <summary>
		/// Rebuilds the whole graph from scratch.
		/// 1. Places a node on every free grid point whose index sum is even.
		/// 2. When <see cref="isOptimizing"/> is set, removes nodes that have all eight
		///    neighbours and re-inserts a sparse set every <see cref="complementLevel"/> cells.
		/// 3. Connects each node to the nearest surviving node in each of the eight directions,
		///    provided a linecast between them is not blocked.
		/// Side effect: <see cref="complementLevel"/> is rounded down to an even value (minimum 2).
		/// </summary>
		public void BakeNodes()
		{
			Debug.Log("Baking Start");
			ClearNodes();
			int xCount = Mathf.FloorToInt((bounds.max.x - bounds.min.x) / pointInterval) + 1;
			int yCount = Mathf.FloorToInt((bounds.max.y - bounds.min.y) / pointInterval) + 1;

			// Keep the level even so complement points stay on the even-sum lattice, and never
			// let it drop below 2: a stride of 0 (or a negative one) would loop forever below.
			complementLevel = Mathf.Max(2, complementLevel - complementLevel % 2);

			for (int i = 0; i < xCount; i++)
			{
				for (int j = 0; j < yCount; j++)
				{
					if ((i + j) % 2 != 0)
						continue;
					Vector2Int cell = new Vector2Int(i, j);
					Vector3 worldPosition = IndexToWorld(cell);
					if (Physics2D.OverlapCircle(worldPosition, agentRadius, blockMask))
						continue;
					nodes.Add(cell, new Node(cell, worldPosition));
				}
			}

			// "lattice" is the full set of free points; "nodes" is what survives pruning.
			// Without optimisation both are the same and only the immediate neighbour is linked.
			Dictionary<Vector2Int, Node> lattice = nodes;
			int maxSteps = 1;
			if (isOptimizing)
			{
				lattice = new Dictionary<Vector2Int, Node>(nodes);
				maxSteps = int.MaxValue;
				PruneInteriorNodes();
				RestoreComplementNodes(xCount, yCount);
			}

			foreach (Node node in nodes.Values)
			{
				node.adjacencies = CollectAdjacencies(node, lattice, xCount, yCount, maxSteps);
			}
			Debug.Log("Baking End");
		}

		// Index offsets to the eight lattice neighbours, in the order adjacencies are stored:
		// right, left, up, down, up-left, up-right, down-left, down-right.
		// Axis-aligned neighbours are two cells away because odd-sum cells hold no node.
		private static readonly Vector2Int[] NeighbourSteps =
		{
			new Vector2Int(2, 0),
			new Vector2Int(-2, 0),
			new Vector2Int(0, 2),
			new Vector2Int(0, -2),
			new Vector2Int(-1, 1),
			new Vector2Int(1, 1),
			new Vector2Int(-1, -1),
			new Vector2Int(1, -1),
		};

		// Removes every node whose eight neighbours all exist. The decision is made against
		// the unmodified dictionary; removal happens afterwards.
		private void PruneInteriorNodes()
		{
			List<Vector2Int> interior = new List<Vector2Int>();
			foreach (Vector2Int cell in nodes.Keys)
			{
				bool surrounded = true;
				foreach (Vector2Int step in NeighbourSteps)
				{
					if (!nodes.ContainsKey(cell + step))
					{
						surrounded = false;
						break;
					}
				}
				if (surrounded)
					interior.Add(cell);
			}

			foreach (Vector2Int cell in interior)
			{
				nodes.Remove(cell);
			}
		}

		// Re-inserts a node on every complementLevel-th free grid point that is currently empty.
		private void RestoreComplementNodes(int xCount, int yCount)
		{
			for (int i = complementLevel; i < xCount; i += complementLevel)
			{
				for (int j = complementLevel; j < yCount; j += complementLevel)
				{
					Vector2Int cell = new Vector2Int(i, j);
					if (nodes.ContainsKey(cell))
						continue;
					Vector3 worldPosition = IndexToWorld(cell);
					if (Physics2D.OverlapCircle(worldPosition, agentRadius, blockMask))
						continue;
					nodes.Add(cell, new Node(cell, worldPosition));
				}
			}
		}

		// Walks outward from the node in each direction over free lattice points (at most
		// maxSteps of them) and links the first surviving node found, unless the straight
		// line to it is blocked. A gap in the lattice ends the walk in that direction.
		private List<Vector2Int> CollectAdjacencies(Node node, Dictionary<Vector2Int, Node> lattice, int xCount, int yCount, int maxSteps)
		{
			List<Vector2Int> adj = new List<Vector2Int>();
			foreach (Vector2Int step in NeighbourSteps)
			{
				for (int k = 1; k <= maxSteps; k++)
				{
					Vector2Int candidate = node.gridPosition + step * k;
					if (candidate.x < 0 || candidate.x >= xCount || candidate.y < 0 || candidate.y >= yCount)
						break;
					if (!lattice.ContainsKey(candidate))
						break;

					Node other;
					if (!nodes.TryGetValue(candidate, out other))
						continue; // pruned point: keep walking

					if (!Physics2D.Linecast(node.worldPositon, other.worldPositon, blockMask))
						adj.Add(candidate);
					break;
				}
			}
			return adj;
		}

        /// <summary>
        /// Placeholder entry point: no search is performed and the result is always null.
        /// </summary>
        /// <param name="from">Start position (unused).</param>
        /// <param name="to">Destination position (unused).</param>
        /// <returns>Always null.</returns>
        public List<Vector3> GetPath(Vector3 from, Vector3 to)
        {
            List<Vector3> noPath = null;
            return noPath;
        }

        /// <summary>
        /// Greedy descent: starting from the node nearest to <paramref name="from"/>, repeatedly
        /// steps to the adjacent node closest to <paramref name="to"/> until within two grid
        /// intervals of it or until no neighbour is closer.
        /// </summary>
        /// <param name="from">World-space start position.</param>
        /// <param name="to">World-space destination.</param>
        /// <returns>
        /// Waypoints starting with the first node. <paramref name="to"/> is appended only when the
        /// walk got within range; a partial path is returned when it got stuck, and an empty list
        /// when either end lies left of the grid or no start node exists.
        /// </returns>
        public List<Vector3> GetPathGreedy(Vector3 from, Vector3 to)
        {
            List<Vector3> path = new List<Vector3>();
            if (WorldToIndex(from, to - from).x < 0 || WorldToIndex(to, from - to).x < 0)
            {
                return path;
            }

            // The nearest grid corner is not necessarily a node (blocked, odd-sum or pruned),
            // so resolve it through FindNode and bail out instead of throwing on a missing key.
            Node cur;
            if (!nodes.TryGetValue(FindNode(from, to - from), out cur))
            {
                return path;
            }
            path.Add(cur.worldPositon);

            while (Vector2.Distance(cur.worldPositon, to) > pointInterval * 2)
            {
                Node next = cur;
                float min = Vector2.Distance(cur.worldPositon, to);
                if (cur.adjacencies != null)
                {
                    foreach (Vector2Int adj in cur.adjacencies)
                    {
                        Node candidate;
                        if (!nodes.TryGetValue(adj, out candidate))
                            continue;
                        float score = Vector2.Distance(candidate.worldPositon, to);
                        if (min > score)
                        {
                            next = candidate;
                            min = score;
                        }
                    }
                }

                // No neighbour is strictly closer: local minimum, give up with what we have.
                if (ReferenceEquals(cur, next))
                {
                    return path;
                }
                path.Add(next.worldPositon);
                cur = next;
            }
            path.Add(to);
            return path;
        }

        /// <summary>
        /// A* search over the baked graph from the node nearest <paramref name="from"/> to the
        /// node nearest <paramref name="to"/>, using straight-line distance as edge cost and
        /// heuristic.
        /// </summary>
        /// <param name="from">World-space start position.</param>
        /// <param name="to">World-space destination.</param>
        /// <returns>
        /// Waypoints after the start node up to the goal node, followed by a final point:
        /// <paramref name="to"/> itself, or (when the destination overlaps an obstacle) the point
        /// where a ray from the last waypoint towards it first hits one. Returns an empty list when
        /// no start/goal node can be resolved, and the result of <see cref="GetPathGreedy"/> when
        /// the goal is unreachable.
        /// </returns>
        public List<Vector3> GetPathAstar(Vector3 from, Vector3 to)
        {
            List<Vector3> path = new List<Vector3>();

            Vector2Int startIndex = FindNode(from, to - from);
            Vector2Int goalIndex = FindNode(to, from - to);
            if (startIndex.x < 0 || goalIndex.x < 0)
            {
                return path;
            }

            // FindNode can hand back an index that holds no node; treat that as "no path"
            // instead of letting the dictionary indexer throw.
            Node start;
            Node goal;
            if (!nodes.TryGetValue(startIndex, out start) || !nodes.TryGetValue(goalIndex, out goal))
            {
                return path;
            }

            PriorityQueue<Node> queue = new PriorityQueue<Node>((a, b) => a.score < b.score ? 1 : a.score == b.score ? 0 : -1);
            List<Node> closed = new List<Node>();
            bool reached = false;

            // Nodes are reused between searches, so clear what the previous one left on the start.
            start.cost = 0;
            start.parent = null;

            queue.Enqueue(start);
            start.state = 1;
            while (queue.Count > 0)
            {
                Node cur = queue.Dequeue();
                if (cur.gridPosition == goal.gridPosition)
                {
                    closed.Add(cur);
                    cur.state = -1;
                    reached = true;
                    break;
                }

                foreach (var adj in cur.adjacencies)
                {
                    Node adjNode = nodes[adj];
                    float throughCur = cur.cost + Vector2.Distance(cur.worldPositon, adjNode.worldPositon);

                    if (adjNode.state > 0)
                    {
                        // Already open: keep the cheaper route.
                        if (adjNode.cost > throughCur)
                        {
                            adjNode.cost = throughCur;
                            adjNode.CalculateScore(to);
                            adjNode.parent = cur;
                            queue.ValueUpdated(adjNode);
                        }
                    }
                    else if (adjNode.state == 0)
                    {
                        adjNode.parent = cur;
                        adjNode.cost = throughCur;
                        adjNode.CalculateScore(to);
                        queue.Enqueue(adjNode);
                        adjNode.state = 1;
                    }
                }
                cur.state = -1;
                closed.Add(cur);
            }

            if (reached)
            {
                for (Node step = goal; step.gridPosition != start.gridPosition; step = step.parent)
                {
                    path.Add(step.worldPositon);
                }
                path.Reverse();
            }

            // Always restore the per-node search flags, including on the unreachable path;
            // leaving them set would make later searches treat these nodes as visited.
            foreach (var close in closed)
            {
                close.state = 0;
            }
            foreach (var open in queue)
            {
                open.state = 0;
            }

            if (!reached)    //Path is blocked
            {
                return GetPathGreedy(from, to);
            }

            if (path.Count > 1 && Physics2D.OverlapCircle(to, agentRadius, blockMask))
            {
                Vector3 last = path[path.Count - 1];
                RaycastHit2D hit = Physics2D.Raycast(last, to - last, float.MaxValue, blockMask);
                // A miss reports point (0, 0); fall back to the requested destination then.
                path.Add(hit.collider != null ? (Vector3)hit.point : to);
            }
            else
            {
                path.Add(to);
            }

            return path;
        }

        /// <summary>
        /// Converts a grid index to its world position (bounds.min plus index times pointInterval).
        /// </summary>
        private Vector3 IndexToWorld(Vector2Int index)
        {
            int column = index.x;
            int row = index.y;
            return IndexToWorld(column, row);
        }

        /// <summary>
        /// World position of grid point (x, y): bounds.min offset by pointInterval per index
        /// step along X and Y; the Z offset is zero.
        /// </summary>
        private Vector3 IndexToWorld(int x, int y)
        {
            Vector3 offsetFromMin = new Vector3(x * pointInterval, y * pointInterval);
            return bounds.min + offsetFromMin;
        }

        /// <summary>
        /// Picks one of the (up to) four grid corners surrounding <paramref name="position"/>,
        /// favouring corners that lie along <paramref name="direction"/> and are close by
        /// (score = cosine of the angle to the corner divided by its distance).
        /// The result is a raw grid index; it is not guaranteed to hold a node.
        /// </summary>
        /// <param name="position">World-space position; only X and Y are considered.</param>
        /// <param name="direction">Preferred direction; it is normalised internally.</param>
        /// <returns>The chosen grid index, or the corner itself when the position sits on it.</returns>
        private Vector2Int WorldToIndex(Vector3 position, Vector2 direction)
        {
            // Below this distance Unity's normalized returns a zero vector, and at exactly zero
            // the score would be 0/0 = NaN, which never wins a comparison.
            const float snapTolerance = 1e-5f;

            float gridX = (position.x - bounds.min.x) / pointInterval;
            float gridY = (position.y - bounds.min.y) / pointInterval;
            int minX = Mathf.FloorToInt(gridX);
            int maxX = Mathf.CeilToInt(gridX);
            int minY = Mathf.FloorToInt(gridY);
            int maxY = Mathf.CeilToInt(gridY);

            Vector2Int[] corners = { new Vector2Int(minX, minY), new Vector2Int(minX, maxY), new Vector2Int(maxX, minY), new Vector2Int(maxX, maxY) };
            Vector2 heading = direction.normalized;
            Vector2Int index = Vector2Int.zero;
            float max = float.MinValue;

            foreach (Vector2Int corner in corners)
            {
                // Work in the XY plane so a Z difference between the position and the
                // bounds cannot distort the distances on a 2D grid.
                Vector2 offset = IndexToWorld(corner) - position;
                float distance = offset.magnitude;
                if (distance < snapTolerance)
                {
                    return corner; // standing on a grid point: that point is the answer
                }

                float score = Vector2.Dot(offset.normalized, heading) / distance;
                if (max < score)
                {
                    index = corner;
                    max = score;
                }
            }
            return index;
        }

		/// <summary>
		/// Resolves a world position to the index of an existing node. The corner chosen by
		/// <see cref="WorldToIndex"/> is used when it holds a node; otherwise nearby indices are
		/// searched: the surrounding complementLevel-sized window when <see cref="isOptimizing"/>
		/// is set, or the four axis neighbours when it is not.
		/// </summary>
		/// <param name="position">World-space position; only X and Y are considered.</param>
		/// <param name="direction">Preferred direction used to rank candidates.</param>
		/// <returns>The index of a node, or (-1, -1) when no node was found nearby.</returns>
		private Vector2Int FindNode(Vector3 position, Vector2 direction)
		{
			Vector2Int index = WorldToIndex(position, direction);
			if (nodes.ContainsKey(index))
				return index;

			// Below this distance the direction to a candidate is undefined (0/0 = NaN score).
			const float snapTolerance = 1e-5f;

			Vector2 heading = direction.normalized;
			Vector2Int target = new Vector2Int(-1, -1); // sentinel callers test with ".x < 0"
			float max = float.MinValue;

			if (isOptimizing)
			{
				// Guard the modulo against a level of zero or less set from code.
				int level = Mathf.Max(1, complementLevel);
				int firstX = index.x - index.x % level;
				int firstY = index.y - index.y % level;

				for (int i = firstX; i <= firstX + level; i++)
				{
					for (int j = firstY; j <= firstY + level; j++)
					{
						Vector2Int key = new Vector2Int(i, j);
						if (!nodes.ContainsKey(key))
							continue;

						Vector2 offset = IndexToWorld(key) - position;
						float distance = offset.magnitude;
						if (distance < snapTolerance)
							return key;

						// NOTE(review): this ranks by angle only (dot / distance is the cosine),
						// whereas the non-optimizing branch also divides by distance. Kept as
						// found; confirm which weighting is intended.
						float score = Vector2.Dot(offset, heading) / distance;
						if (max < score)
						{
							target = key;
							max = score;
						}
					}
				}
				return target;
			}

			Vector2Int[] adjs = { index + Vector2Int.up, index + Vector2Int.down, index + Vector2Int.right, index + Vector2Int.left };
			foreach (Vector2Int adj in adjs)
			{
				if (!nodes.ContainsKey(adj))
					continue;

				Vector2 offset = IndexToWorld(adj) - position;
				float distance = offset.magnitude;
				if (distance < snapTolerance)
					return adj;

				float score = Vector2.Dot(offset.normalized, heading) / distance;
				if (max < score)
				{
					target = adj;
					max = score;
				}
			}
			return target;
		}

        /// <summary>
        /// Discards every baked node, after which <see cref="isBaked"/> reports false.
        /// </summary>
        public void ClearNodes()
        {
            this.nodes.Clear();
        }
    }

    /// <summary>
    /// One vertex of the baked graph: its grid index, its world position, the indices of the
    /// nodes it connects to, and the scratch values a path search writes while running.
    /// </summary>
    class Node
    {
        // Index of this node in the grid; also its key in the node dictionary.
        public Vector2Int gridPosition;
        // World-space position of the node.
        public Vector3 worldPositon;
        // Grid indices of connected nodes; null until assigned by the baker.
        public List<Vector2Int> adjacencies = null;
        // Accumulated path cost from the search start.
        public float cost = 0f;
        // cost plus the heuristic distance to the destination, see CalculateScore.
        public float score = 0f;
        // Node this one was reached from during the search.
        public Node parent = null;
        // -1 : closed, 1 : opened, 0 : neutral
        public int state = 0;

        /// <summary>
        /// Creates an unconnected node with all search values at their defaults.
        /// </summary>
        public Node(Vector2Int gridPosition, Vector3 worldPositon)
        {
            this.gridPosition = gridPosition;
            this.worldPositon = worldPositon;
        }

        /// <summary>
        /// Refreshes <see cref="score"/> from the current <see cref="cost"/> and the
        /// straight-line XY distance to <paramref name="destination"/>.
        /// </summary>
        public void CalculateScore(Vector3 destination)
        {
            float remaining = GetHeuristic(destination);
            score = cost + remaining;
        }

        // Straight-line distance in the XY plane from this node to the destination.
        private float GetHeuristic(Vector3 destination)
        {
            float straightLine = Vector2.Distance(worldPositon, destination);
            return straightLine;
        }
    }

    /// <summary>
    /// Comparer stub for <see cref="Node"/>. No ordering is implemented yet.
    /// </summary>
    class NodeComparer : IComparer<Node>
    {
        /// <summary>
        /// Not implemented; every call throws <see cref="NotImplementedException"/>.
        /// </summary>
        public int Compare(Node x, Node y)
        {
            throw new NotImplementedException();
        }
    }
}