##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2019, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################

"""
Implementation of Driver class
It is a wrapper around the actual psycopg2 driver, and connection
object.

"""
import datetime
from flask import session, request
from flask_login import current_user
from flask_babelex import gettext
import psycopg2
from psycopg2.extensions import adapt

import config
from pgadmin.model import Server, User
from .keywords import ScanKeyword
from ..abstract import BaseDriver
from .connection import Connection
from .server_manager import ServerManager


class Driver(BaseDriver):
    """
    class Driver(BaseDriver):

    This driver acts as a wrapper around the psycopg2 connection driver
    implementation. We will be using psycopg2 for making connections with
    the PostgreSQL/EDB Postgres Advanced Server (EnterpriseDB).

    Properties:
    ----------

    * Version (string):
        Version of psycopg2 driver

    Methods:
    -------
    * get_connection(sid, database, conn_id, auto_reconnect)
    - It returns a Connection class object, which may/may not be connected
      to the database server for this session

    * release_connection(sid, database, conn_id)
    - It releases the connection object for the given conn_id/database for this
      session.

    * connection_manager(sid)
    - It returns the server connection manager for this session.
    """

    def __init__(self, **kwargs):
        self.managers = dict()

        super(Driver, self).__init__()

    def connection_manager(self, sid=None):
        """
        connection_manager(...)

        Returns the ServerManager object of the given server for the current
        session. It will create a new ServerManager object (if necessary).

        Parameters:
            sid
            - Server ID (must be an integer)

        Returns:
            The ServerManager object for the server, or None when there is no
            server with the given id.
        """
        assert (sid is not None and isinstance(sid, int))

        server_data = Server.query.filter_by(id=sid).first()
        if server_data is None:
            return None

        # The managers of a session are keyed by the server id as a string.
        server_key = str(sid)

        if session.sid not in self.managers:
            # Nothing registered for this session yet: create its registry,
            # and restore the managers stored in the session data (if any).
            self.managers[session.sid] = managers = dict()
            if '__pgsql_server_managers' in session:
                session_managers = session['__pgsql_server_managers'].copy()

                for server in Server.query.filter_by(user_id=current_user.id):
                    manager = managers[str(server.id)] = ServerManager(server)
                    if server.id in session_managers:
                        manager._restore(session_managers[server.id])
                        manager.update_session()
        else:
            managers = self.managers[session.sid]
            if server_key in managers:
                manager = managers[server_key]
                manager._restore_connections()
                manager.update_session()

        # Remember the last activity of this session, gc() compares against it
        # to find the idle sessions.
        managers['pinged'] = datetime.datetime.now()

        if server_key not in managers:
            # Reuse the server record fetched above, instead of running the
            # very same query a second time.
            managers[server_key] = ServerManager(server_data)

        return managers[server_key]

    def Version(cls):
        """
        Version(...)

        Returns the value of the '__version__' attribute of the psycopg2
        module.

        Raises:
            Exception
            - When psycopg2 has no '__version__' attribute, or it is empty.
        """
        driver_version = getattr(psycopg2, '__version__', None)

        if not driver_version:
            raise Exception(
                "Driver Version information for psycopg2 is not available!"
            )

        return driver_version

    def libpq_version(cls):
        """
        Returns the value of the '__libpq_version__' attribute of the psycopg2
        module.

        Raises:
            Exception
            - When psycopg2 has no '__libpq_version__' attribute, or it is
              empty (or zero).
        """
        loaded_version = getattr(psycopg2, '__libpq_version__', None)

        if not loaded_version:
            raise Exception(
                "libpq version information is not available!"
            )

        return loaded_version

    def get_connection(
            self, sid, database=None, conn_id=None, auto_reconnect=True
    ):
        """
        get_connection(...)

        Asks the server manager of the server (identified by sid) for the
        connection object of the given connection-id/database, and returns
        whatever the manager returns.

        Parameters:
            sid
            - Server ID
            database
            - Database, on which the connection needs to be made.
              It is passed to the server manager as is (None by default).
            conn_id
            - Identification string for the connection. This will be used by
              certain tools, which require a dedicated connection,
              i.e. Debugger, Query Tool, etc.
            auto_reconnect
            - Defines, whether an attempt should be made to reconnect to the
              database server automatically, when the connection has been
              lost for any reason. It is passed to the server manager as is.
        """
        return self.connection_manager(sid).connection(
            database, conn_id, auto_reconnect
        )

    def release_connection(self, sid, database=None, conn_id=None):
        """
        Asks the server manager of the given server to release the connection
        for the given connection-id/database in this session, and returns the
        result of that call.

        Parameters:
            sid
            - Server ID
            database
            - Database of the connection to be released
            conn_id
            - Identification string of the connection to be released
        """
        manager = self.connection_manager(sid)

        return manager.release(database, conn_id)

    def delete_manager(self, sid):
        """
        Releases the server manager of the given server id (if there is one),
        and removes it from the registry of the current session.

        Parameters:
            sid
            - Server ID
        """
        manager = self.connection_manager(sid)
        if manager is not None:
            manager.release()

        # Drop the entry only when the session, and the server are registered.
        session_managers = self.managers.get(session.sid)
        if session_managers is not None:
            session_managers.pop(str(sid), None)

    def gc(self):
        """
        Release the server managers of the sessions, which have not pinged
        the server for config.MAX_SESSION_IDLE_TIME minutes (or more).

        60 minutes are used, when config.MAX_SESSION_IDLE_TIME is not set, and
        the timeout is never less than 20 minutes. The current session is
        never released, its ping time is refreshed instead.
        """

        # Minimum session idle is 20 minutes
        idle_minutes = max(config.MAX_SESSION_IDLE_TIME or 60, 20)
        idle_limit = datetime.timedelta(minutes=idle_minutes)

        now = datetime.datetime.now()

        for sess_id, sess_mgr in self.managers.items():
            if sess_id == session.sid:
                sess_mgr['pinged'] = now
                continue

            if now - sess_mgr['pinged'] < idle_limit:
                continue

            # The registry holds the 'pinged' timestamp too, hence - pick the
            # ServerManager objects only (collected before releasing them).
            idle_managers = [
                m for m in sess_mgr.values() if isinstance(m, ServerManager)
            ]
            for mgr in idle_managers:
                mgr.release()

    def gc_own(self):
        """
        Calls release() on every ServerManager object registered for the
        current session (e.g. to be used on logout).

        Nothing happens, when there is no registry for the current session.
        """

        sess_mgr = self.managers.get(session.sid, None)

        if not sess_mgr:
            return

        for mgr in sess_mgr.values():
            # Skip the entries, which are not managers (i.e. 'pinged')
            if isinstance(mgr, ServerManager):
                mgr.release()

    @staticmethod
    def qtLiteral(value, forceQuote=False):
        """
        Returns the given value as a quoted SQL literal, as generated by the
        psycopg2 adapter of the value.

        Parameters:
            value
            - Value to be converted to a literal
            forceQuote
            - When it is True, the literal is wrapped in single quotes,
              unless it starts with a single quote already.
        """
        quoted = adapt(value)

        # Not all adapted objects have encoding
        # e.g.
        # psycopg2.extensions.BOOLEAN
        # psycopg2.extensions.FLOAT
        # psycopg2.extensions.INTEGER
        # etc...
        if hasattr(quoted, 'encoding'):
            quoted.encoding = 'utf8'

        literal = quoted.getquoted()
        if isinstance(literal, bytes):
            literal = literal.decode('utf-8')

        if forceQuote is not True:
            return literal

        # Convert the literal to the string to use the startswith(...)
        literal = str(literal)
        if literal.startswith("'"):
            return literal

        return "'" + literal + "'"

    @staticmethod
    def ScanKeywordExtraLookup(key):
        # UNRESERVED_KEYWORD      0
        # COL_NAME_KEYWORD        1
        # TYPE_FUNC_NAME_KEYWORD  2
        # RESERVED_KEYWORD        3
        extraKeywords = {
            'connect': 3,
            'convert': 3,
            'distributed': 0,
            'exec': 3,
            'log': 0,
            'long': 3,
            'minus': 3,
            'nocache': 3,
            'number': 3,
            'package': 3,
            'pls_integer': 3,
            'raw': 3,
            'return': 3,
            'smalldatetime': 3,
            'smallfloat': 3,
            'smallmoney': 3,
            'sysdate': 3,
            'systimestap': 3,
            'tinyint': 3,
            'tinytext': 3,
            'varchar2': 3
        }

        return extraKeywords.get(key, None) or ScanKeyword(key)

    @staticmethod
    def needsQuoting(key, forTypes):
        value = key
        valNoArray = value

        # check if the string is number or not
        if isinstance(value, int):
            return True
        # certain types should not be quoted even though it contains a space.
        # Evilness.
        elif forTypes and value[-2:] == u"[]":
            valNoArray = value[:-2]

        if forTypes and valNoArray.lower() in [
            u'bit varying',
            u'"char"',
            u'character varying',
            u'double precision',
            u'timestamp without time zone',
            u'timestamp with time zone',
            u'time without time zone',
            u'time with time zone',
            u'"trigger"',
            u'"unknown"'
        ]:
            return False

        # If already quoted?, If yes then do not quote again
        if forTypes and valNoArray:
            if valNoArray.startswith('"') \
                    or valNoArray.endswith('"'):
                return False

        if u'0' <= valNoArray[0] <= u'9':
            return True

        for c in valNoArray:
            if (not (u'a' <= c <= u'z') and c != u'_' and
                    not (u'0' <= c <= u'9')):
                return True

        # check string is keywaord or not
        category = Driver.ScanKeywordExtraLookup(value)

        if category is None:
            return False

        # UNRESERVED_KEYWORD
        if category == 0:
            return False

        # COL_NAME_KEYWORD
        if forTypes and category == 1:
            return False

        return True

    @staticmethod
    def qtTypeIdent(conn, *args):
        # We're not using the conn object at the moment, but - we will
        # modify the
        # logic to use the server version specific keywords later.
        res = None
        value = None

        for val in args:
            # DataType doesn't have len function then convert it to string
            if not hasattr(val, '__len__'):
                val = str(val)

            if len(val) == 0:
                continue
            if hasattr(str, 'decode') and not isinstance(val, unicode):
                # Handling for python2
                try:
                    val = str(val).encode('utf-8')
                except UnicodeDecodeError:
                    # If already unicode, most likely coming from db
                    val = str(val).decode('utf-8')
            value = val

            if Driver.needsQuoting(val, True):
                value = value.replace("\"", "\"\"")
                value = "\"" + value + "\""

            res = ((res and res + '.') or '') + value

        return res

    @staticmethod
    def qtIdent(conn, *args):
        # We're not using the conn object at the moment, but - we will
        # modify the logic to use the server version specific keywords later.
        res = None
        value = None

        for val in args:
            if type(val) == list:
                return map(lambda w: Driver.qtIdent(conn, w), val)

            # DataType doesn't have len function then convert it to string
            if not hasattr(val, '__len__'):
                val = str(val)

            if hasattr(str, 'decode') and not isinstance(val, unicode):
                # Handling for python2
                try:
                    val = str(val).encode('utf-8')
                except UnicodeDecodeError:
                    # If already unicode, most likely coming from db
                    val = str(val).decode('utf-8')

            if len(val) == 0:
                continue

            value = val

            if Driver.needsQuoting(val, False):
                value = value.replace("\"", "\"\"")
                value = "\"" + value + "\""

            res = ((res and res + '.') or '') + value

        return res
