##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2019, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################

from __future__ import print_function
import sys
import random

from regression.python_test_utils import test_utils
from regression.feature_utils.base_feature_test import BaseFeatureTest
from selenium.webdriver import ActionChains
from selenium.common.exceptions import StaleElementReferenceException
from regression.feature_utils.locators import QueryToolLocators


class CheckForXssFeatureTest(BaseFeatureTest):
    """
    Tests to check if pgAdmin4 is vulnerable to XSS.

    Each check fetches the HTML source of a UI element and searches it for
    the HTML-escaped form of a value that contains markup. If the escaped
    form is found the value was rendered as text and we are not vulnerable,
    otherwise we might be.

    We will cover,
        1) Browser Tree (aciTree)
        2) Properties Tab (Backform)
        3) Dependents Tab (BackGrid)
        4) SQL Tab (Code Mirror)
        5) Query Tool (SlickGrid), its history and the view/edit data grid
        6) Graphical explain plan tooltip
    """

    scenarios = [
        ("Test XSS check for panels and query tool", dict())
    ]
    # Object names deliberately contain markup. The table name gets a random
    # numeric suffix, evaluated once when the class body is executed.
    test_table_name = "<h1>X" + str(random.randint(1000, 3000))
    test_type_name = '"<script>alert(1)</script>"'

    def before(self):
        """Create the type, table and constraint whose names carry markup."""
        test_utils.create_type(
            self.server, self.test_db, self.test_type_name,
            ['"<script>alert(1)</script>" "char"',
             '"1<script>alert(1)</script>" "char"']
        )
        test_utils.create_table(
            self.server, self.test_db, self.test_table_name,
            ['"<script>alert(1)</script>" char',
             'typcol ' + self.test_type_name]
        )
        # This is needed to test dependents tab (eg: BackGrid)
        test_utils.create_constraint(
            self.server, self.test_db,
            self.test_table_name,
            "unique", "<h1 onmouseover='console.log(2);'>Y"
        )

    def runTest(self):
        """Walk through every covered UI area and run its XSS check."""
        self.page.wait_for_spinner_to_disappear()
        self.page.add_server(self.server)
        self._tables_node_expandable()
        self._check_xss_in_browser_tree()
        self._check_xss_in_properties_tab()
        self._check_xss_in_sql_tab()

        # Sometimes the Dependents tab does not show any info, so refresh
        # the page and expand the tree down to the table node again.
        self.page.refresh_page()
        self.page.toggle_open_servers_group()
        self._tables_node_expandable()
        self._check_xss_in_dependents_tab()

        # Query tool
        self.page.open_query_tool()
        self._check_xss_in_query_tool()
        self._check_xss_in_query_tool_history()
        self.page.close_query_tool()
        # Query tool view/edit data
        self.page.open_view_data(self.test_db)
        self._check_xss_view_data()
        self.page.close_data_grid()

        # Explain module
        self.page.open_query_tool()
        self._check_xss_in_explain_module()
        self.page.close_query_tool()

    def after(self):
        """Remove the server from the browser and drop the test table."""
        # NOTE(review): the type created in before() is not dropped here;
        # confirm that it is cleaned up elsewhere (e.g. with the database).
        self.page.remove_server(self.server)
        test_utils.delete_table(
            self.server, self.test_db, self.test_table_name)

    def _tables_node_expandable(self):
        """Expand the browser tree down to, and select, the test table."""
        self.page.toggle_open_server(self.server['name'])
        self.page.toggle_open_tree_item('Databases')
        self.page.toggle_open_tree_item(self.test_db)
        self.page.toggle_open_tree_item('Schemas')
        self.page.toggle_open_tree_item('public')
        self.page.toggle_open_tree_item('Tables')
        self.page.select_tree_item(self.test_table_name)

    def _check_xss_in_browser_tree(self):
        """Check that the table name is escaped in the browser tree."""
        print(
            "\n\tChecking the Browser tree for XSS vulnerabilities",
            file=sys.stderr, end=""
        )
        # Fetch the inner html & check for escaped characters
        source_code = self.page.find_by_xpath(
            "//*[@id='tree']"
        ).get_attribute('innerHTML')

        self._check_escaped_characters(
            source_code,
            "&lt;h1&gt;X",
            "Browser tree"
        )

    def _check_xss_in_properties_tab(self):
        """Check that the table name is escaped in the Properties tab."""
        print(
            "\n\tChecking the Properties tab for XSS vulnerabilities",
            file=sys.stderr, end=""
        )
        self.page.click_tab("Properties")
        source_code = self.page.find_by_xpath(
            "//span[contains(@class,'uneditable-input')]"
        ).get_attribute('innerHTML')
        self._check_escaped_characters(
            source_code,
            "&lt;h1&gt;X",
            "Properties tab (Backform Control)"
        )

    def _check_xss_in_sql_tab(self):
        """Check that the table name is escaped in the SQL tab."""
        print(
            "\n\tChecking the SQL tab for for XSS vulnerabilities",
            file=sys.stderr, end=""
        )
        self.page.click_tab("SQL")
        # Fetch the inner html & check for escaped characters
        source_code = self.page.find_by_xpath(
            "//*[contains(@class,'CodeMirror-lines') and "
            "contains(.,'CREATE TABLE')]"
        ).get_attribute('innerHTML')

        self._check_escaped_characters(
            source_code,
            "&lt;h1&gt;X",
            "SQL tab (Code Mirror)"
        )

    def _check_xss_in_dependents_tab(self):
        """
        Check that the constraint name is escaped in the Dependents tab.

        Relies on the constraint with a markup-carrying name that is
        created in before().
        """
        print(
            "\n\tChecking the Dependents tab for XSS vulnerabilities",
            file=sys.stderr, end=""
        )
        self.page.click_tab("Dependents")

        source_code = self.page.find_by_xpath(
            "//*[@id='5']/table/tbody/tr/td/div/div/div[2]/"
            "table/tbody/tr/td[2]"
        ).get_attribute('innerHTML')

        self._check_escaped_characters(
            source_code,
            "public.&lt;h1 onmouseover='console.log(2);'&gt;Y",
            "Dependents tab (BackGrid)"
        )

    def _check_xss_in_query_tool(self):
        """Run a query returning markup and check the result cell."""
        print(
            "\n\tChecking the SlickGrid cell for XSS vulnerabilities",
            file=sys.stderr, end=""
        )
        self.page.fill_codemirror_area_with(
            "select '<img src=\"x\" onerror=\"console.log(1)\">'"
        )
        self.page.find_by_id("btn-flash").click()

        result_row = self.page.find_by_xpath(
            "//*[contains(@class, 'ui-widget-content') and "
            "contains(@style, 'top:0px')]"
        )

        cells = result_row.find_elements_by_tag_name('div')

        # Skip the first element as it is the row number.
        source_code = cells[1].get_attribute('innerHTML')

        self._check_escaped_characters(
            source_code,
            '&lt;img src="x" onerror="console.log(1)"&gt;',
            "Query tool (SlickGrid)"
        )

    def _check_xss_in_query_tool_history(self):
        """
        Check the Query History entry, message and error text for escaping.

        Raises StaleElementReferenceException if the error text element is
        still stale after the retries are used up.
        """
        print(
            "\n\tChecking the Query Tool history for XSS vulnerabilities... ",
            file=sys.stderr, end=""
        )
        # The string literal is left unterminated; presumably this is what
        # makes the server return the error text checked further below.
        self.page.fill_codemirror_area_with(
            "select '<script>alert(1)</script>"
        )
        self.page.find_by_css_selector(
            QueryToolLocators.btn_execute_query_css).click()

        self.page.click_tab('Query History')

        # Check for history entry
        history_ele = self.page.find_by_css_selector(
            ".query-history div.query-group:first-child"
            " .list-item.selected .query"
        )

        source_code = history_ele.get_attribute('innerHTML')

        self._check_escaped_characters(
            source_code,
            '&lt;script&gt;alert(1)&lt;/script&gt;',
            "Query tool (History Entry)"
        )

        # Check for history details message
        history_ele = self.driver\
            .find_element_by_css_selector(".query-detail .content-value")

        source_code = history_ele.get_attribute('innerHTML')

        self._check_escaped_characters(
            source_code,
            '&lt;script&gt;alert(1)&lt;/script&gt;',
            "Query tool (History Details-Message)"
        )

        # Check for history details error message. Reading the element can
        # hit a stale reference, so the lookup is attempted twice.
        retry = 2
        while True:
            try:
                history_ele = self.page.find_by_css_selector(
                    ".query-detail .history-error-text"
                )
                source_code = history_ele.get_attribute('innerHTML')
                break
            except StaleElementReferenceException:
                retry -= 1
                if retry == 0:
                    # Without this, source_code would still hold the
                    # details message markup validated above and the
                    # error check would pass without reading the error.
                    raise

        self._check_escaped_characters(
            source_code,
            '&lt;script&gt;alert(1)&lt;/script&gt;',
            "Query tool (History Details-Error)"
        )

        self.page.click_tab('Query Editor')

    def _check_xss_view_data(self):
        """Check that a column header is escaped in the view data grid."""
        print(
            "\n\tChecking the SlickGrid cell for XSS vulnerabilities",
            file=sys.stderr, end=""
        )

        # Look the header up through the page helper first (presumably it
        # waits for the element - TODO confirm), then collect all headers.
        self.page.find_by_css_selector(".slick-header-column")
        cells = self.driver.\
            find_elements_by_css_selector(".slick-header-column")

        # The first header is the row number; the header under test is
        # currently found at index 4.
        source_code = cells[4].get_attribute('innerHTML')

        self._check_escaped_characters(
            source_code,
            '&lt;script&gt;alert(1)&lt;/script&gt;',
            "View Data (SlickGrid)"
        )

    def _check_xss_in_explain_module(self):
        """
        Check that the table name is escaped in the explain plan tooltip.

        Re-raises the last error if the plan image cannot be hovered after
        three attempts.
        """
        print(
            "\n\tChecking the Graphical Explain plan for XSS vulnerabilities",
            file=sys.stderr, end=""
        )
        self.page.fill_codemirror_area_with(
            'select * from "{0}"'.format(self.test_table_name)
        )

        self.page.find_by_css_selector(
            QueryToolLocators.btn_explain).click()
        self.page.wait_for_query_tool_loading_indicator_to_disappear()
        self.page.click_tab('Explain')

        max_attempts = 3
        for attempt in range(max_attempts):
            # Re-try logic: hover over the plan image to show its tooltip.
            try:
                ActionChains(self.driver).move_to_element(
                    self.driver.find_element_by_css_selector(
                        'div.pgadmin-explain-container > svg > g > g > image'
                    )
                ).perform()
                break
            except Exception:
                # Any failure is retried; the last one is reported and
                # propagated so that the test fails.
                if attempt < max_attempts - 1:
                    continue
                print(
                    "\n\tUnable to locate the explain container to check"
                    " the image tooltip for XSS",
                    file=sys.stderr, end=""
                )
                raise

        source_code = self.driver.find_element_by_id(
            'toolTip').get_attribute('innerHTML')

        self._check_escaped_characters(
            source_code,
            "&lt;h1&gt;X",
            "Explain tab (Graphical explain plan)"
        )

    def _check_escaped_characters(self, source_code, string_to_find, source):
        """
        Fail unless the escaped string is present in the element's HTML.

        :param source_code: HTML code of the element under test
        :param string_to_find: HTML-escaped text expected in source_code
        :param source: name of the UI area, used in the failure message
        :raises AssertionError: when string_to_find is not in source_code
        """
        # For XSS we need to search against element's html code. Raise
        # explicitly rather than using an assert statement so that the
        # check is not stripped when Python runs with -O.
        if string_to_find not in source_code:
            raise AssertionError(
                "{0} might be vulnerable to XSS ".format(source)
            )