MysqlPassword.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.mysql;

import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.ErrorCode;
import org.apache.doris.common.ErrorReport;
import org.apache.doris.qe.GlobalVariable;

import com.google.common.base.Strings;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.UnsupportedEncodingException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;

// this is stolen from MySQL
//
// The main idea is that no passwords are sent between client & server on
// connection and that no passwords are saved in mysql in a decodable form.
//
// On connection a random string is generated and sent to the client.
// The client generates a new string with a random generator inited with
// the hash values from the password and the sent string.
// This 'check' string is sent to the server where it is compared with
// a string generated from the stored hash_value of the password and the
// random string.
//
// The password is saved (in user.password) by using the PASSWORD() function in
// mysql.
//
// In MySQL the original is a .c file because it's used in libmysqlclient, which is
// entirely in C (it needs to be portable to a variety of systems).
// Example:
// update user set password=PASSWORD("hello") where user="test"
// This saves a hashed number as a string in the password field.
//
// The new authentication is performed in following manner:
//
// SERVER:  public_seed=create_random_string()
//          send(public_seed)
//
// CLIENT:  recv(public_seed)
//          hash_stage1=sha1("password")
//          hash_stage2=sha1(hash_stage1)
//          reply=xor(hash_stage1, sha1(public_seed,hash_stage2))
//
//          these three steps are done in scramble()
//
//          send(reply)
//
// SERVER:  recv(reply)
//          hash_stage1=xor(reply, sha1(public_seed,hash_stage2))
//          candidate_hash2=sha1(hash_stage1)
//          check(candidate_hash2==hash_stage2)
//
//          these three steps are done in check_scramble()
public class MysqlPassword {
    private static final Logger LOG = LogManager.getLogger(MysqlPassword.class);
    // TODO(zhaochun): this is duplicated with handshake packet.
    public static final byte[] EMPTY_PASSWORD = new byte[0];
    // Length in bytes of a SHA-1 digest, which is also the length of the scramble.
    public static final int SCRAMBLE_LENGTH = 20;
    // Length of the stored form: one leading '*' followed by the hex-encoded digest.
    public static final int SCRAMBLE_LENGTH_HEX_LENGTH = 2 * SCRAMBLE_LENGTH + 1;
    public static final byte PVERSION41_CHAR = '*';
    private static final byte[] DIG_VEC_UPPER = {'0', '1', '2', '3', '4', '5', '6', '7',
            '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'};
    private static final Random random = new SecureRandom();
    // Characters counted as "special characters" by validatePlainPassword().
    private static final Set<Character> complexCharSet;
    public static final int MIN_PASSWORD_LEN = 8;

    static {
        complexCharSet = "~!@#$%^&*()_+|<>,.?/:;'[]{}".chars().mapToObj(c -> (char) c).collect(Collectors.toSet());
    }

    // Create a random challenge string of 'len' bytes.
    // Every returned byte is an ASCII letter (a-z or A-Z).
    public static byte[] createRandomString(int len) {
        byte[] bytes = new byte[len];
        random.nextBytes(bytes);
        // NOTE: MySQL challenge string can't contain 0.
        for (int i = 0; i < len; ++i) {
            boolean isLetter = (bytes[i] >= 'a' && bytes[i] <= 'z')
                    || (bytes[i] >= 'A' && bytes[i] <= 'Z');
            if (!isLetter) {
                // Java's % keeps the sign of the dividend, so a negative byte would
                // map below 'a' (possibly onto a non-letter). floorMod always yields
                // a value in [0, 25].
                bytes[i] = (byte) ('a' + Math.floorMod(bytes[i], 26));
            }
        }
        return bytes;
    }

    // XOR two byte arrays of equal length.
    // Returns a new array, or null when the lengths differ.
    private static byte[] xorCrypt(byte[] s1, byte[] s2) {
        if (s1.length != s2.length) {
            return null;
        }
        byte[] res = new byte[s1.length];
        for (int i = 0; i < s1.length; ++i) {
            res[i] = (byte) (s1[i] ^ s2[i]);
        }
        return res;
    }

    // Check that scrambled message corresponds to the password; the function
    // is used by server to check that received reply is authentic.
    //
    //   scramble     clients' reply, presumably produced by scramble()
    //   message      original random string, previously sent to client
    //                (presumably the first argument of scramble())
    //   hashStage2   hex2octet-decoded database entry
    //   All params are IN.
    //
    // RETURN VALUE
    //   true   password is correct
    //   false  password is invalid, an argument is null, the scramble does not
    //          have the length of a SHA-1 digest, or SHA-1 is not available
    public static boolean checkScramble(byte[] scramble, byte[] message, byte[] hashStage2) {
        if (scramble == null || message == null || hashStage2 == null) {
            return false;
        }
        MessageDigest md = null;
        try {
            md = MessageDigest.getInstance("SHA-1");
        } catch (NoSuchAlgorithmException e) {
            LOG.warn("No SHA-1 Algorithm when compute password.");
            return false;
        }
        // compute result1: XOR(scramble, SHA-1 (public_seed + hashStage2))
        md.update(message);
        md.update(hashStage2);
        byte[] hashStage1 = xorCrypt(md.digest(), scramble);
        if (hashStage1 == null) {
            // The reply does not have the digest length, so it cannot be valid.
            // Reject it instead of failing with a NullPointerException below.
            return false;
        }

        // compute result2: SHA-1(result1)
        md.reset();
        md.update(hashStage1);
        byte[] candidateHash2 = md.digest();
        // compare result2 and hashStage2 using MessageDigest.isEqual()
        return MessageDigest.isEqual(candidateHash2, hashStage2);
    }

    // MySQL client use this function to form scramble password
    //   seed:     random string received from the server
    //   password: plaintext password
    // Returns XOR(SHA-1(password), SHA-1(seed + SHA-1(SHA-1(password)))),
    // or null if UTF-8 or SHA-1 is not available.
    public static byte[] scramble(byte[] seed, String password) {
        byte[] scramblePassword = null;
        try {
            byte[] passBytes = password.getBytes("UTF-8");
            MessageDigest md = MessageDigest.getInstance("SHA-1");
            byte[] hashStage1 = md.digest(passBytes);
            md.reset();
            byte[] hashStage2 = md.digest(hashStage1);
            md.reset();
            md.update(seed);
            scramblePassword = xorCrypt(hashStage1, md.digest(hashStage2));
        } catch (UnsupportedEncodingException e) {
            // no UTF-8 character set
            LOG.warn("No UTF-8 character set when compute password.");
        } catch (NoSuchAlgorithmException e) {
            // No SHA-1 algorithm
            LOG.warn("No SHA-1 Algorithm when compute password.");
        }

        return scramblePassword;
    }

    // Convert plaintext password into the corresponding 2-staged hashed password
    // Used for users to set password
    // Returns SHA-1(SHA-1(password)), or null if UTF-8 or SHA-1 is not available.
    private static byte[] twoStageHash(String password) {
        try {
            byte[] passBytes = password.getBytes("UTF-8");
            MessageDigest md = MessageDigest.getInstance("SHA-1");
            byte[] hashStage1 = md.digest(passBytes);
            md.reset();
            byte[] hashStage2 = md.digest(hashStage1);

            return hashStage2;
        } catch (UnsupportedEncodingException e) {
            // no UTF-8 character set
            LOG.warn("No UTF-8 character set when compute password.");
        } catch (NoSuchAlgorithmException e) {
            // No SHA-1 algorithm
            LOG.warn("No SHA-1 Algorithm when compute password.");
        }

        return null;
    }

    // convert octet 'from' to upper-case hex, written into 'to' starting at 'toOff'
    // NOTE: this function assume that to buffer is enough
    private static void octetToHexSafe(byte[] to, int toOff, byte[] from) {
        int j = toOff;
        for (int i = 0; i < from.length; i++) {
            int val = from[i] & 0xff;
            to[j++] = DIG_VEC_UPPER[val >> 4];
            to[j++] = DIG_VEC_UPPER[val & 0x0f];
        }
    }

    // Value of one hex digit character ('0'-'9', 'A'-'F'); any other input is
    // treated as a lower-case hex digit without being validated.
    private static int fromByte(int b) {
        return (b >= '0' && b <= '9') ? b - '0'
                : (b >= 'A' && b <= 'F') ? b - 'A' + 10 : b - 'a' + 10;
    }

    // convert hex 'from' to octet 'to'
    // fromOff: offset of 'from' to convert, there is no pointer in JAVA
    // Decoding stops at the end of whichever array is exhausted first, and a
    // trailing unpaired hex digit is ignored, so malformed input can never index
    // outside either array.
    private static void hexToOctetSafe(byte[] to, byte[] from, int fromOff) {
        int j = 0;
        for (int i = fromOff; i + 1 < from.length && j < to.length; i += 2) {
            int high = fromByte(from[i] & 0xff);
            int low = fromByte(from[i + 1] & 0xff);
            to[j++] = (byte) ((high << 4) + low);
        }
    }

    // Make password which stored in palo meta from plain text
    // Returns EMPTY_PASSWORD for a null or empty input, otherwise '*' followed by
    // the upper-case hex form of SHA-1(SHA-1(plainPasswd)).
    public static byte[] makeScrambledPassword(String plainPasswd) {
        if (Strings.isNullOrEmpty(plainPasswd)) {
            return EMPTY_PASSWORD;
        }

        byte[] hashStage2 = twoStageHash(plainPasswd);
        byte[] passwd = new byte[SCRAMBLE_LENGTH_HEX_LENGTH];
        passwd[0] = (PVERSION41_CHAR);
        octetToHexSafe(passwd, 1, hashStage2);
        return passwd;
    }

    // Convert scrambled password from ascii hex string to binary form.
    // The first byte of 'password' (the '*' marker) is skipped.
    // Returns EMPTY_PASSWORD for a null or empty input, otherwise an array of
    // SCRAMBLE_LENGTH bytes.
    public static byte[] getSaltFromPassword(byte[] password) {
        if (password == null || password.length == 0) {
            return EMPTY_PASSWORD;
        }
        byte[] hashStage2 = new byte[SCRAMBLE_LENGTH];
        hexToOctetSafe(hashStage2, password, 1);
        return hashStage2;
    }

    // Check whether 'plainPass' hashes to the stored 'scrambledPass'.
    // The comparison goes through MessageDigest.isEqual(), as in checkScramble(),
    // instead of an early-exit loop; a null 'scrambledPass' simply does not match.
    public static boolean checkPlainPass(byte[] scrambledPass, String plainPass) {
        byte[] pass = makeScrambledPassword(plainPass);
        return MessageDigest.isEqual(pass, scrambledPass);
    }

    // Validate a password given in its stored (hashed) text form.
    // A null or empty string yields EMPTY_PASSWORD. Otherwise the string is
    // upper-cased and must be exactly SCRAMBLE_LENGTH_HEX_LENGTH bytes: a leading
    // '*' followed by hex digits. Violations are reported through
    // ErrorReport.reportAnalysisException (presumably by throwing
    // AnalysisException; confirm in ErrorReport).
    public static byte[] checkPassword(String passwdString) throws AnalysisException {
        if (Strings.isNullOrEmpty(passwdString)) {
            return EMPTY_PASSWORD;
        }

        byte[] passwd = null;
        try {
            passwdString = passwdString.toUpperCase();
            passwd = passwdString.getBytes("UTF-8");
        } catch (UnsupportedEncodingException e) {
            ErrorReport.reportAnalysisException(ErrorCode.ERR_UNKNOWN_ERROR);
        }
        if (passwd.length != SCRAMBLE_LENGTH_HEX_LENGTH || passwd[0] != PVERSION41_CHAR) {
            ErrorReport.reportAnalysisException(ErrorCode.ERR_PASSWD_LENGTH, 41);
        }

        for (int i = 1; i < passwd.length; ++i) {
            boolean isDigit = passwd[i] >= '0' && passwd[i] <= '9';
            boolean isUpperHex = passwd[i] >= 'A' && passwd[i] <= 'F';
            if (!(isDigit || isUpperHex)) {
                ErrorReport.reportAnalysisException(ErrorCode.ERR_PASSWD_LENGTH, 41);
            }
        }

        return passwd;
    }

    // Check a plaintext password against the given validation policy.
    // Only GlobalVariable.VALIDATE_PASSWORD_POLICY_STRONG is enforced here: the
    // password must have at least MIN_PASSWORD_LEN characters and contain at least
    // 3 of the 4 character types (digit, lower case, upper case, special
    // character). Any other policy value accepts every input.
    public static void validatePlainPassword(long validaPolicy, String text) throws AnalysisException {
        if (validaPolicy == GlobalVariable.VALIDATE_PASSWORD_POLICY_STRONG) {
            if (Strings.isNullOrEmpty(text) || text.length() < MIN_PASSWORD_LEN) {
                throw new AnalysisException(
                        "Violate password validation policy: STRONG. The password must be at least 8 characters");
            }

            // Count how many of the four character types are present.
            int typeCount = 0;
            if (text.chars().anyMatch(Character::isDigit)) {
                typeCount++;
            }
            if (text.chars().anyMatch(Character::isLowerCase)) {
                typeCount++;
            }
            if (text.chars().anyMatch(Character::isUpperCase)) {
                typeCount++;
            }
            if (text.chars().anyMatch(c -> complexCharSet.contains((char) c))) {
                typeCount++;
            }
            if (typeCount < 3) {
                throw new AnalysisException(
                        "Violate password validation policy: STRONG. The password must contain at least 3 types of "
                                + "numbers, uppercase letters, lowercase letters and special characters.");
            }
        }
    }
}