PostgreSQLDBAdapter.java

/*******************************************************************************
 * Copyright (c) 2019, RISE AB
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without 
 * modification, are permitted provided that the following conditions 
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, 
 *    this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice, 
 *    this list of conditions and the following disclaimer in the documentation 
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 
 * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *******************************************************************************/
package se.sics.ace.examples;

import se.sics.ace.AceException;
import se.sics.ace.as.DBConnector;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Properties;

/**
 * {@link SQLDBAdapter} implementation that produces and executes the
 * PostgreSQL flavour of the SQL needed to create, adapt queries for, and
 * wipe the database.
 *
 * Configure the instance with
 * {@link #setParams(String, String, String, String)} before calling any
 * other method; until then the user, password, database name and URL are
 * {@code null}.
 *
 * @author Sebastian Echeverria
 *
 */
public class PostgreSQLDBAdapter implements SQLDBAdapter {

    /**
     * Name of the database that admin connections are opened against
     * (see {@link #getAdminConnection(String, String)}).
     */
    public static final String BASE_DB = "postgres";

    /**
     * The default connection URL for the database.
     */
    public static final String DEFAULT_DB_URL 
        = "jdbc:postgresql://localhost:5432";

    /**
     * The database user that owns and accesses the application database.
     */
    protected String user;

    /**
     * The password of {@link #user}.
     */
    protected String password;

    /**
     * The JDBC URL of the server, without a trailing database name.
     */
    protected String baseDbUrl;

    /**
     * The name of the application database.
     */
    protected String dbName;

    /**
     * Stores the connection settings, substituting a default for every
     * argument that is {@code null}.
     *
     * @param user  the database user, or null for 
     *              {@code DBConnector.DEFAULT_USER}
     * @param pwd  the user's password, or null for 
     *             {@code DBConnector.DEFAULT_PASSWORD}
     * @param dbName  the database name, or null for 
     *                {@code DBConnector.DEFAULT_DB_NAME}
     * @param dbUrl  the server URL without database name, or null for 
     *               {@link #DEFAULT_DB_URL}
     */
    @Override
    public void setParams(String user, String pwd, String dbName, String dbUrl) {
        this.user = (user != null) ? user : DBConnector.DEFAULT_USER;
        this.password = (pwd != null) ? pwd : DBConnector.DEFAULT_PASSWORD;
        this.dbName = (dbName != null) ? dbName : DBConnector.DEFAULT_DB_NAME;
        this.baseDbUrl = (dbUrl != null) ? dbUrl : DEFAULT_DB_URL;
    }

    /**
     * Opens a connection to the {@link #BASE_DB} database with the given
     * admin credentials.
     *
     * @param adminUser  the admin user name
     * @param adminPwd  the admin password
     * @return  a new connection; the caller is responsible for closing it
     * @throws SQLException  if the connection cannot be established
     */
    @Override
    public Connection getAdminConnection(String adminUser, String adminPwd) throws SQLException {
        return openConnection(PostgreSQLDBAdapter.BASE_DB, adminUser, adminPwd);
    }

    /**
     * Opens a connection to the configured application database with the
     * configured user and password.
     *
     * @return  a new connection; the caller is responsible for closing it
     * @throws SQLException  if the connection cannot be established
     */
    @Override
    public Connection getDBConnection() throws SQLException {
        return openConnection(this.dbName, this.user, this.password);
    }

    /**
     * Creates the configured database user as a login role with the
     * configured password, unless a user of that name is already listed in
     * {@code pg_catalog.pg_user}.
     *
     * @param adminUser  the admin user name
     * @param adminPwd  the admin password
     * @throws AceException  if connecting or executing the SQL fails; the
     *                       underlying {@link SQLException} is the cause
     */
    @Override
    public synchronized void createUser(String adminUser, String adminPwd) throws AceException {
        String findUser 
            = "SELECT 1 FROM pg_catalog.pg_user WHERE usename = ?";
        // Utility statements such as CREATE ROLE cannot take bind 
        // parameters, so the name and password are escaped by hand.
        String createRole = "CREATE ROLE " + quoteIdentifier(this.user)
                + " LOGIN PASSWORD " + quoteLiteral(this.password) + ";";

        try (Connection adminConn = getAdminConnection(adminUser, adminPwd)) {
            if (rowExists(adminConn, findUser, this.user)) {
                return;
            }
            try (Statement stmt = adminConn.createStatement()) {
                stmt.execute(createRole);
            }
        } catch (SQLException e) {
            throw toAceException(e);
        }
    }

    /**
     * Creates the configured database, owned by the configured user, and
     * all tables and enum types inside it.
     *
     * This is a "create if not exists" operation at database level only: 
     * when a database of the configured name already exists the method 
     * returns without touching it, so tables missing from an existing 
     * database are NOT created.
     *
     * @param adminUser  the admin user name
     * @param adminPwd  the admin password
     * @throws AceException  if connecting or executing the SQL fails; the
     *                       underlying {@link SQLException} is the cause
     */
    @Override
    public synchronized void createDBAndTables(String adminUser, String adminPwd) 
            throws AceException {
        try {
            if (!createDatabaseIfMissing(adminUser, adminPwd)) {
                return;
            }
            createTables();
        } catch (SQLException e) {
            throw toAceException(e);
        }
    }

    /**
     * Rewrites a generic SQL statement into its PostgreSQL form.
     *
     * INSERT statements that mention the key types table or the token
     * types table are replaced by a fixed two-parameter INSERT that casts 
     * the second parameter to the matching enum type. CREATE TABLE 
     * statements have every occurrence of "&lt;dbName&gt;." removed. Any
     * other statement is returned unchanged.
     *
     * @param sqlQuery  the statement to adapt
     * @return  the statement to use with PostgreSQL
     */
    @Override
    public String updateEngineSpecificSQL(String sqlQuery)
    {
        // In PostgreSQL, enums need casting.
        boolean isInsert = sqlQuery.contains("INSERT");
        if (isInsert && sqlQuery.contains(DBConnector.keyTypesTable)) {
            return "INSERT INTO " + DBConnector.keyTypesTable 
                    + " VALUES (?,?::keytype)";
        }
        if (isInsert && sqlQuery.contains(DBConnector.tokenTypesTable)) {
            return "INSERT INTO " + DBConnector.tokenTypesTable 
                    + " VALUES (?,?::tokentype)";
        }

        // Create table statements do not take the db name in PostgreSQL.
        // replace() returns the input unchanged when the prefix is absent.
        if (sqlQuery.contains("CREATE TABLE")) {
            return sqlQuery.replace(this.dbName + ".", "");
        }
        return sqlQuery;
    }

    /**
     * Terminates all other sessions connected to the configured database,
     * then drops the database and the configured user if they exist.
     *
     * @param adminUser  the admin user name
     * @param adminPwd  the admin password
     * @throws AceException  if connecting or executing the SQL fails; the
     *                       underlying {@link SQLException} is the cause
     */
    @Override
    public void wipeDB(String adminUser, String adminPwd) throws AceException
    {
        String dropConnections 
            = "SELECT pg_terminate_backend(pg_stat_activity.pid) " 
                + "    FROM pg_stat_activity " 
                + "    WHERE pg_stat_activity.datname = ?" 
                + "      AND pid <> pg_backend_pid()";
        String dropDB = "DROP DATABASE IF EXISTS " 
                + quoteIdentifier(this.dbName) + ";";
        String dropUser = "DROP USER IF EXISTS " 
                + quoteIdentifier(this.user) + ";";

        try (Connection adminConn = getAdminConnection(adminUser, adminPwd)) {
            // The database cannot be dropped while sessions are using it.
            try (PreparedStatement kill 
                    = adminConn.prepareStatement(dropConnections)) {
                kill.setString(1, this.dbName);
                kill.execute();
            }
            try (Statement stmt = adminConn.createStatement()) {
                stmt.execute(dropDB);
                stmt.execute(dropUser);
            }
        } catch (SQLException e) {
            throw toAceException(e);
        }
    }

    /**
     * Opens a connection to one database on the configured server.
     *
     * @param database  the database name appended to the base URL
     * @param userName  the user to log in as
     * @param pwd  the user's password
     * @return  a new connection
     * @throws SQLException  if the connection cannot be established
     */
    private Connection openConnection(String database, String userName, 
            String pwd) throws SQLException {
        Properties connectionProps = new Properties();
        connectionProps.put("user", userName);
        connectionProps.put("password", pwd);
        return DriverManager.getConnection(
                this.baseDbUrl + "/" + database, connectionProps);
    }

    /**
     * Creates the configured database unless it already exists.
     *
     * @param adminUser  the admin user name
     * @param adminPwd  the admin password
     * @return  true if the database was created by this call, false if it
     *          already existed
     * @throws SQLException  if connecting or executing the SQL fails
     */
    private boolean createDatabaseIfMissing(String adminUser, 
            String adminPwd) throws SQLException {
        String checkDB = "SELECT datname FROM pg_catalog.pg_database "
                + "WHERE datname = ?";
        // The names are quoted so that the database and its owner are 
        // stored with exactly the spelling used by the check above and by 
        // the JDBC URL; unquoted names would be folded to lower case.
        String createDB = "CREATE DATABASE " + quoteIdentifier(this.dbName)
                + " WITH OWNER= " + quoteIdentifier(this.user)
                + " ENCODING = 'UTF8' TEMPLATE = template0 " 
                + " CONNECTION LIMIT = -1;";

        try (Connection adminConn = getAdminConnection(adminUser, adminPwd)) {
            if (rowExists(adminConn, checkDB, this.dbName)) {
                return false;
            }
            try (Statement stmt = adminConn.createStatement()) {
                stmt.execute(createDB);
            }
            return true;
        }
    }

    /**
     * Creates all tables and enum types and initialises the cti counter.
     *
     * Table creation in PostgreSQL needs to be done with a connection 
     * using the local user and not the admin user, so that the local 
     * user will be automatically set as the owner of the tables.
     *
     * @throws SQLException  if connecting or executing the SQL fails
     */
    private void createTables() throws SQLException {
        //rs id, default expiration time, token psk, auth psk, rpk
        String createRs = "CREATE TABLE " +  DBConnector.rsTable + "("
                + DBConnector.rsIdColumn + " varchar(255) NOT NULL, "
                + DBConnector.expColumn + " bigint NOT NULL, "
                + DBConnector.tokenPskColumn + " bytea, "
                + DBConnector.authPskColumn + " bytea, "
                + DBConnector.rpkColumn + " bytea,"
                + "PRIMARY KEY (" + DBConnector.rsIdColumn + "));";

        String createC = "CREATE TABLE " +  DBConnector.cTable + " ("
                + DBConnector.clientIdColumn + " varchar(255) NOT NULL, "
                + DBConnector.defaultAud + " varchar(255), "
                + DBConnector.defaultScope + " varchar(255), "
                + DBConnector.authPskColumn + " bytea, "
                + DBConnector.rpkColumn + " bytea,"
                + "PRIMARY KEY (" + DBConnector.clientIdColumn + "));";

        String createProfiles = "CREATE TABLE "
                + DBConnector.profilesTable + "("
                + DBConnector.idColumn + " varchar(255) NOT NULL, "
                + DBConnector.profileColumn + " varchar(255) NOT NULL);";

        String keyType = "CREATE TYPE keytype AS ENUM ('PSK', 'RPK', 'TST');";

        String createKeyTypes = "CREATE TABLE "
                + DBConnector.keyTypesTable + "("
                + DBConnector.idColumn + " varchar(255) NOT NULL, "
                + DBConnector.keyTypeColumn + " keytype);";

        String createScopes = "CREATE TABLE "
                + DBConnector.scopesTable + "("
                + DBConnector.rsIdColumn + " varchar(255) NOT NULL, "
                + DBConnector.scopeColumn + " varchar(255) NOT NULL);";

        String tokenType 
            = "CREATE TYPE tokenType AS ENUM ('CWT', 'REF', 'TST');";

        String createTokenTypes = "CREATE TABLE "
                + DBConnector.tokenTypesTable + "("
                + DBConnector.rsIdColumn + " varchar(255) NOT NULL, "
                + DBConnector.tokenTypeColumn + " tokenType);";

        String createAudiences = "CREATE TABLE "
                + DBConnector.audiencesTable + "("
                + DBConnector.rsIdColumn + " varchar(255) NOT NULL, "
                + DBConnector.audColumn + " varchar(255) NOT NULL);";

        String createCose =  "CREATE TABLE "
                + DBConnector.coseTable + "("
                + DBConnector.rsIdColumn + " varchar(255) NOT NULL, "
                + DBConnector.coseColumn + " varchar(255) NOT NULL);";

        String createClaims = "CREATE TABLE "
                + DBConnector.claimsTable + "("
                + DBConnector.ctiColumn + " varchar(255) NOT NULL, "
                + DBConnector.claimNameColumn + " SMALLINT NOT NULL,"
                + DBConnector.claimValueColumn + " bytea);";

        String createOldTokens = "CREATE TABLE "
                + DBConnector.oldTokensTable + "("
                + DBConnector.ctiColumn + " varchar(255) NOT NULL, "
                + DBConnector.claimNameColumn + " SMALLINT NOT NULL,"
                + DBConnector.claimValueColumn + " bytea);";

        String createCtiCtr = "CREATE TABLE "
                + DBConnector.ctiCounterTable + "("
                + DBConnector.ctiCounterColumn + " bigint);";

        String initCtiCtr = "INSERT INTO "
                + DBConnector.ctiCounterTable
                + " VALUES (0);";

        String createTokenLog = "CREATE TABLE "
                + DBConnector.cti2clientTable + "("
                + DBConnector.ctiColumn + " varchar(255) NOT NULL, "
                + DBConnector.clientIdColumn + " varchar(255) NOT NULL,"
                + " PRIMARY KEY (" + DBConnector.ctiColumn + "));";

        String createGrant2Cti = "CREATE TABLE "
                + DBConnector.grant2ctiTable + "("
                + DBConnector.grantColumn + " varchar(255) NOT NULL, "
                + DBConnector.ctiColumn + " varchar(255) NOT NULL, "
                + DBConnector.grantValidColumn + " BOOLEAN DEFAULT TRUE, "
                + " PRIMARY KEY (" + DBConnector.grantColumn + ","
                + DBConnector.ctiColumn + "));";

        String createGrant2RSInfo = "CREATE TABLE "
                + DBConnector.grant2RSInfoTable + "("
                + DBConnector.grantColumn + " varchar(255) NOT NULL, "
                + DBConnector.claimNameColumn + " SMALLINT NOT NULL,"
                + DBConnector.claimValueColumn + " bytea);";

        // Order matters: each enum type must exist before the table that 
        // uses it, and the counter table before its initial row.
        String[] statements = {
            createRs, createC, createProfiles,
            keyType, createKeyTypes,
            createScopes,
            tokenType, createTokenTypes,
            createAudiences, createCose, createClaims, createOldTokens,
            createCtiCtr, initCtiCtr,
            createTokenLog, createGrant2Cti, createGrant2RSInfo
        };

        try (Connection userConn = getDBConnection();
             Statement stmt = userConn.createStatement())
        {
            for (String sql : statements) {
                stmt.execute(sql);
            }
        }
    }

    /**
     * Runs a query with a single string parameter and reports whether it
     * returned at least one row.
     *
     * @param conn  the connection to use; not closed by this method
     * @param query  the SQL text containing exactly one '?' placeholder
     * @param param  the value bound to the placeholder
     * @return  true if the result has at least one row
     * @throws SQLException  if preparing or executing the query fails
     */
    private static boolean rowExists(Connection conn, String query, 
            String param) throws SQLException {
        try (PreparedStatement stmt = conn.prepareStatement(query)) {
            stmt.setString(1, param);
            try (ResultSet rows = stmt.executeQuery()) {
                return rows.next();
            }
        }
    }

    /**
     * Wraps a name in double quotes, doubling any embedded double quote,
     * so it can be used as an SQL identifier exactly as spelled.
     *
     * @param name  the raw identifier; null is rendered as "null", the 
     *              same way string concatenation would
     * @return  the quoted identifier
     */
    private static String quoteIdentifier(String name) {
        return "\"" + String.valueOf(name).replace("\"", "\"\"") + "\"";
    }

    /**
     * Wraps a value in single quotes, doubling any embedded single quote,
     * so it can be used as an SQL string constant.
     *
     * NOTE(review): this assumes the server treats backslashes in ordinary
     * string constants literally (standard_conforming_strings = on) - 
     * confirm for the targeted server versions.
     *
     * @param value  the raw value; null is rendered as "null", the same 
     *               way string concatenation would
     * @return  the quoted string constant
     */
    private static String quoteLiteral(String value) {
        return "'" + String.valueOf(value).replace("'", "''") + "'";
    }

    /**
     * Converts an SQL failure into an {@link AceException} carrying the
     * same message and, where possible, the original exception as cause.
     *
     * @param e  the SQL failure
     * @return  the exception to throw to the caller
     */
    private static AceException toAceException(SQLException e) {
        AceException wrapped = new AceException(e.getMessage());
        try {
            wrapped.initCause(e);
        } catch (IllegalStateException ignored) {
            // The AceException constructor already fixed its cause; the 
            // message alone is then all that can be passed on.
        }
        return wrapped;
    }
}