﻿using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEditor;

namespace BS
{
    /// <summary>
    /// Bakes a 2D navigation graph over <see cref="bounds"/> and answers path queries on it.
    /// Grid indices are spaced <see cref="pointInterval"/> apart, but only indices whose
    /// parity (x + y) is even become nodes, giving a diamond-shaped lattice. Lattice points
    /// that overlap a 2D collider on <see cref="blockMask"/> are skipped.
    /// </summary>
    public class NodalPathfinding2D : MonoBehaviour
    {
        /// <summary>World-space region sampled when baking.</summary>
        public Bounds bounds;

        /// <summary>Distance between neighbouring grid indices along each axis.</summary>
        [Range(0.1f, 1f)]
        public float pointInterval = 0.5f;

        /// <summary>Layers whose 2D colliders block a lattice point.</summary>
        public LayerMask blockMask;

        // Baked nodes keyed by grid index; empty until BakeNodes runs.
        private Dictionary<Vector2Int, Node> nodes = new Dictionary<Vector2Int, Node>();

        // Offsets from a node to its potential neighbours: straight neighbours are two indices
        // away (one index away lands on a non-lattice point), diagonal ones are one index away
        // on each axis. The order matters: GetPathGreedy keeps the first of equally good options.
        private static readonly Vector2Int[] NeighbourOffsets =
        {
            new Vector2Int(2, 0),   // Right
            new Vector2Int(-2, 0),  // Left
            new Vector2Int(0, 2),   // Up
            new Vector2Int(0, -2),  // Down
            new Vector2Int(-1, 1),  // UpLeft
            new Vector2Int(1, 1),   // UpRight
            new Vector2Int(-1, -1), // DownLeft
            new Vector2Int(1, -1),  // DownRight
        };

        /// <summary>True once BakeNodes has produced at least one node.</summary>
        public bool isBaked { get { return nodes.Count > 0; } }

        /// <summary>Draw the baked graph (or an un-baked preview) in the Scene view.</summary>
        public bool enableGizmos = true;


        private void OnDrawGizmos()
        {
            if (!enableGizmos)
                return;

            if (isBaked)
            {
                // Baked: draw every node and a line to each of its neighbours.
                Gizmos.color = new Color(0, 1, 0, .3f);
                foreach (var node in nodes.Values)
                {
                    Gizmos.DrawSphere(node.worldPositon, pointInterval / 5);
                    foreach (var adj in node.adjacencies)
                    {
                        Gizmos.DrawLine(node.worldPositon, nodes[adj].worldPositon);
                    }
                }
                return;
            }

            // Not baked: outline the bounds and preview the lattice points BakeNodes would sample,
            // red where a blocking collider overlaps, green where free. Integer indices plus
            // IndexToWorld keep the preview identical to the baked positions (accumulating
            // floats drifts and can drop the last row/column).
            Gizmos.DrawWireCube(bounds.center, bounds.size);
            int xCount, yCount;
            GetGridSize(out xCount, out yCount);
            for (int i = 0; i < xCount; i++)
            {
                for (int j = 0; j < yCount; j++)
                {
                    if (!IsLatticePoint(i, j))
                        continue;
                    Vector3 point = IndexToWorld(i, j);
                    if (Physics2D.OverlapPoint(point, blockMask))
                        Gizmos.color = Color.red;
                    else
                        Gizmos.color = Color.green;
                    Gizmos.DrawSphere(point, pointInterval / 5);
                }
            }
        }

        /// <summary>
        /// Rebuilds the graph: creates a node on every unblocked lattice point inside
        /// <see cref="bounds"/>, then links each node to the neighbouring nodes that exist.
        /// </summary>
        public void BakeNodes()
        {
            Debug.Log("Baking Start");
            ClearNodes();
            int xCount, yCount;
            GetGridSize(out xCount, out yCount);

            // Pass 1: one node per free lattice point.
            for (int i = 0; i < xCount; i++)
            {
                for (int j = 0; j < yCount; j++)
                {
                    if (!IsLatticePoint(i, j))
                        continue;
                    Vector2Int index = new Vector2Int(i, j);
                    Vector3 worldPosition = IndexToWorld(index);
                    if (Physics2D.OverlapPoint(worldPosition, blockMask))
                        continue;
                    nodes.Add(index, new Node(index, worldPosition));
                }
            }

            // Pass 2: link neighbours. Keys are snapshotted because the (struct) values are
            // written back into the dictionary inside the loop.
            foreach (Vector2Int pos in new List<Vector2Int>(nodes.Keys))
            {
                Node node = nodes[pos];
                List<Vector2Int> adj = new List<Vector2Int>();
                foreach (Vector2Int offset in NeighbourOffsets)
                {
                    Vector2Int neighbour = pos + offset;
                    if (nodes.ContainsKey(neighbour))
                        adj.Add(neighbour);
                }
                // NOTE(review): only the endpoints are tested for colliders; a thin obstacle lying
                // between two lattice points is not detected. Consider a Physics2D.Linecast here.
                node.adjacencies = adj;
                nodes[pos] = node;
            }
            Debug.Log("Baking End");
        }

        /// <summary>
        /// Default path query; uses <see cref="GetPathAStar"/>.
        /// Returns an empty list when either end is blocked or no path exists.
        /// </summary>
        public List<Vector3> GetPath(Vector3 from, Vector3 to)
        {
            return GetPathAStar(from, to);
        }

        /// <summary>
        /// Greedy best-first walk: from the start node, repeatedly step to the neighbour nearest
        /// <paramref name="to"/> until within two intervals of it, then append <paramref name="to"/>.
        /// If no neighbour is closer (local minimum), the partial path is returned without
        /// <paramref name="to"/>. Returns an empty list if either end has no usable node.
        /// </summary>
        public List<Vector3> GetPathGreedy(Vector3 from, Vector3 to)
        {
            List<Vector3> path = new List<Vector3>();
            Vector2Int startIndex, goalIndex;
            if (!TryGetEndpoints(from, to, out startIndex, out goalIndex))
                return path;

            Node cur = nodes[startIndex];
            path.Add(cur.worldPositon);

            // Each step strictly decreases the distance to `to`, so the loop terminates.
            while (Vector2.Distance(cur.worldPositon, to) > pointInterval * 2)
            {
                Node next = cur;
                float min = Vector2.Distance(cur.worldPositon, to);
                foreach (var adj in cur.adjacencies)
                {
                    float score = Vector2.Distance(nodes[adj].worldPositon, to);
                    if (min > score)
                    {
                        next = nodes[adj];
                        min = score;
                    }
                }
                if (next.gridPosition == cur.gridPosition)
                    return path; // stuck: no neighbour is closer
                path.Add(next.worldPositon);
                cur = next;
            }
            path.Add(to);
            return path;
        }

        /// <summary>
        /// Shortest path over the baked graph using A* (Euclidean edge costs, straight-line
        /// heuristic). The result is the node positions from the start node to the goal node,
        /// followed by <paramref name="to"/>. Returns an empty list if either end has no usable
        /// node or the two nodes are not connected.
        /// </summary>
        public List<Vector3> GetPathAStar(Vector3 from, Vector3 to)
        {
            List<Vector3> path = new List<Vector3>();
            Vector2Int startIndex, goalIndex;
            if (!TryGetEndpoints(from, to, out startIndex, out goalIndex))
                return path;

            Vector3 goalPosition = nodes[goalIndex].worldPositon;
            Dictionary<Vector2Int, Vector2Int> cameFrom = new Dictionary<Vector2Int, Vector2Int>();
            Dictionary<Vector2Int, float> gScore = new Dictionary<Vector2Int, float>();
            Dictionary<Vector2Int, float> fScore = new Dictionary<Vector2Int, float>();
            HashSet<Vector2Int> closed = new HashSet<Vector2Int>();
            List<Vector2Int> open = new List<Vector2Int>();

            gScore[startIndex] = 0f;
            fScore[startIndex] = Vector3.Distance(nodes[startIndex].worldPositon, goalPosition);
            open.Add(startIndex);

            while (open.Count > 0)
            {
                // Pop the open node with the lowest f-score (linear scan; no heap needed).
                int best = 0;
                for (int k = 1; k < open.Count; k++)
                {
                    if (fScore[open[k]] < fScore[open[best]])
                        best = k;
                }
                Vector2Int current = open[best];
                open[best] = open[open.Count - 1];
                open.RemoveAt(open.Count - 1);

                if (current == goalIndex)
                    return BuildPath(cameFrom, current, to);

                // The straight-line heuristic never overestimates the Euclidean edge costs and is
                // consistent, so a closed node never needs to be reopened.
                closed.Add(current);
                Node node = nodes[current];
                foreach (Vector2Int adj in node.adjacencies)
                {
                    if (closed.Contains(adj))
                        continue;
                    Vector3 adjPosition = nodes[adj].worldPositon;
                    float tentative = gScore[current] + Vector3.Distance(node.worldPositon, adjPosition);
                    float known;
                    if (gScore.TryGetValue(adj, out known) && tentative >= known)
                        continue;
                    cameFrom[adj] = current;
                    gScore[adj] = tentative;
                    fScore[adj] = tentative + Vector3.Distance(adjPosition, goalPosition);
                    if (!open.Contains(adj))
                        open.Add(adj);
                }
            }

            // Open set exhausted: the goal is not reachable from the start.
            return path;
        }

        // Walks cameFrom back from goalIndex and returns node positions ordered start -> goal,
        // followed by the exact target position `to`.
        private List<Vector3> BuildPath(Dictionary<Vector2Int, Vector2Int> cameFrom, Vector2Int goalIndex, Vector3 to)
        {
            List<Vector3> path = new List<Vector3>();
            Vector2Int step = goalIndex;
            path.Add(nodes[step].worldPositon);
            Vector2Int previous;
            while (cameFrom.TryGetValue(step, out previous))
            {
                step = previous;
                path.Add(nodes[step].worldPositon);
            }
            path.Reverse();
            path.Add(to);
            return path;
        }

        // Resolves the start node (facing `to`) and the goal node (facing `from`) once per query.
        // Returns false, without resolving the goal, if the start has no usable node.
        private bool TryGetEndpoints(Vector3 from, Vector3 to, out Vector2Int startIndex, out Vector2Int goalIndex)
        {
            goalIndex = new Vector2Int(-1, -1);
            startIndex = WorldToIndex(from, to - from);
            if (startIndex.x < 0)
                return false;
            goalIndex = WorldToIndex(to, from - to);
            return goalIndex.x >= 0;
        }

        // Number of grid indices along x and y that fit inside bounds (both edges inclusive).
        private void GetGridSize(out int xCount, out int yCount)
        {
            xCount = Mathf.FloorToInt((bounds.max.x - bounds.min.x) / pointInterval) + 1;
            yCount = Mathf.FloorToInt((bounds.max.y - bounds.min.y) / pointInterval) + 1;
        }

        // True for indices that can hold a node: (x + y) even. Negative odd sums give -1 here,
        // so they are correctly reported as non-lattice.
        private static bool IsLatticePoint(int x, int y)
        {
            return (x + y) % 2 == 0;
        }

        private Vector3 IndexToWorld(Vector2Int index)
        {
            return IndexToWorld(index.x, index.y);
        }

        // World position of a grid index, measured from bounds.min (z stays at bounds.min.z).
        private Vector3 IndexToWorld(int x, int y)
        {
            return bounds.min + new Vector3(x * pointInterval, y * pointInterval);
        }

        /// <summary>
        /// Picks the baked node used to enter the graph from <paramref name="position"/>: finds the
        /// nearest non-lattice index (the centre of a diamond of four lattice points), then returns
        /// the existing node among those four that lies furthest along <paramref name="direction"/>.
        /// Returns (-1, -1) and logs "Blocked!!" when none of the four nodes exists.
        /// </summary>
        private Vector2Int WorldToIndex(Vector3 position, Vector2 direction)
        {
            // Grid cell containing the position. The upper corner is always floor + 1 (not Ceil),
            // so the cell keeps four distinct corners even when the position lies exactly on a
            // grid line; with Ceil the "centre" could collapse onto a lattice point whose four
            // neighbours are never nodes, falsely reporting a free point as blocked.
            int minX = Mathf.FloorToInt((position.x - bounds.min.x) / pointInterval);
            int minY = Mathf.FloorToInt((position.y - bounds.min.y) / pointInterval);
            int maxX = minX + 1;
            int maxY = minY + 1;

            // Exactly one diagonal pair of the cell's corners is non-lattice; take the closer one.
            Vector2Int center;
            if (!IsLatticePoint(minX, minY))
            {
                center =
                    Vector2.Distance(IndexToWorld(minX, minY), position) > Vector2.Distance(IndexToWorld(maxX, maxY), position) ?
                    new Vector2Int(maxX, maxY) :
                    new Vector2Int(minX, minY);
            }
            else
            {
                center =
                    Vector2.Distance(IndexToWorld(minX, maxY), position) > Vector2.Distance(IndexToWorld(maxX, minY), position) ?
                    new Vector2Int(maxX, minY) :
                    new Vector2Int(minX, maxY);
            }

            Vector2Int[] adjs = { center + Vector2Int.up, center + Vector2Int.down, center + Vector2Int.left, center + Vector2Int.right };

            Vector2Int index = new Vector2Int(-1, -1);
            float max = float.MinValue;

            foreach (var adj in adjs)
            {
                if (!nodes.ContainsKey(adj))
                    continue;
                float score = Vector2.Dot(nodes[adj].worldPositon - position, direction.normalized);
                if (max < score)
                {
                    index = adj;
                    max = score;
                }
            }
            if (index.x < 0)
                Debug.Log("Blocked!!");
            return index;
        }

        /// <summary>Discards all baked nodes.</summary>
        public void ClearNodes()
        {
            nodes.Clear();
        }
    }

    /// <summary>
    /// One baked lattice point: its grid index, its world position, and the grid indices of the
    /// nodes it is linked to.
    /// </summary>
    struct Node
    {
        public Vector2Int gridPosition;
        public Vector3 worldPositon; // (sic) name kept as-is for existing callers
        public List<Vector2Int> adjacencies;

        /// <summary>
        /// Creates a node with no adjacency list yet (<see cref="adjacencies"/> is null until
        /// it is assigned by the caller).
        /// </summary>
        public Node(Vector2Int gridPosition, Vector3 worldPositon)
        {
            adjacencies = null;
            this.worldPositon = worldPositon;
            this.gridPosition = gridPosition;
        }
    }

    /// <summary>
    /// Orders nodes by straight-line distance from a fixed destination, nearest first
    /// (e.g. for sorting candidates or as a priority-queue comparer).
    /// </summary>
    class NodeComparer : IComparer<Node>
    {
        private readonly Vector3 destination;

        /// <param name="destination">World position that distances are measured to.</param>
        public NodeComparer(Vector3 destination)
        {
            this.destination = destination;
        }

        /// <summary>
        /// Negative when <paramref name="x"/> is closer to the destination than <paramref name="y"/>,
        /// positive when farther, zero when equally far.
        /// </summary>
        public int Compare(Node x, Node y)
        {
            // Squared distances order identically to distances and skip two square roots.
            float xDistance = (x.worldPositon - destination).sqrMagnitude;
            float yDistance = (y.worldPositon - destination).sqrMagnitude;
            return xDistance.CompareTo(yDistance);
        }
    }
}