UPSIDER Tech Blog

作って学ぶMySQLプロトコル

カード事業部でバックエンドエンジニアをしている Mimura です。弊社ではマイクロサービスごとに要件や特性に合ったデータベースが選定されますが、私の関わるサービスではMySQLが選定されることが多いです。

普段は database/sql や ORMライブラリを通じて当たり前のようにMySQLとやり取りしていますが、その裏側でどのような通信が流れているのかは意識する機会が多くありません。そこで今回は、MySQLのClient/Server Protocolを自前で実装し、TCPソケットでMySQLサーバーに直接接続してハンドシェイクと認証を行うクライアントをGoで書いてみました。

本記事ではMySQLサーバーに対して接続を確立し、ハンドシェイクと認証を行うクライアントを実装してみます。実装を通じて以下のような学びが得られると思います。

  • バイナリプロトコルの読み書き
  • チャレンジレスポンス認証の仕組み(mysql_native_password)
  • パケットベースの通信プロトコルの理解

準備

バックグラウンドでMySQLコンテナを起動しておきます。

docker run -d \
  --name mysql \
  -p 3306:3306 \
  -e MYSQL_ROOT_PASSWORD=password \
  -e MYSQL_USER=user \
  -e MYSQL_PASSWORD=password \
  mysql:8.0.43 \
  --default-authentication-plugin=mysql_native_password

本記事執筆時点での最新LTSはMySQL 8.4ですが、簡単に mysql_native_password プラグインで認証にするために今回はMySQL 8.0.43を使用します1。以降では、このMySQLコンテナに対してクライアントを実装し、通信を試みます。

MySQLプロトコルの基礎知識

MySQLがクライアントとサーバー間の通信を行う際には、その名の通りのMySQLプロトコルが使用されます。MySQLプロトコルはTCPの上でパケットという単位でデータがやり取りされます。パケットは以下のような構造をしています。

Type Name Description
int<3> payload_length ペイロードの長さ
int<1> sequence_id シーケンス番号
string<var> payload 実際のデータ

ここで、int<num>num バイトの整数を表し、string<var> は可変長の文字列を表します。 それぞれ役割は名前の通りで、先頭の4バイトがヘッダーとして扱われます。すなわち、パケットのやり取りでは先頭4バイトのヘッダーを読んでpayload_lengthを取得し、その後に続くデータをpayloadとして読み取ることになりそうだなと分かります。

例えば、01 00 00 00 01 というパケットは、payload_lengthが1、sequence_idが0、payloadが 01 であることを表します。(payload_length は Little endianで表されていることに注意)

基本となるデータ構造を理解したところで、MySQLプロトコルの具体的な通信を見ていきます。MySQLの通信は Connection Phase と Command Phase の2つのフェーズからなります。

Connection Phase

Connection Phaseは、クライアントとサーバーが接続を確立し認証を行うフェーズです。このフェーズは以下の3つのステップで構成されます。2

1. Initial Handshake

クライアントがサーバーに TCP接続を確立すると、サーバーは最初に Initial Handshake Packet を送信してきます。このパケットには、クライアントが認証を行うために必要な情報が含まれています。

Initial Handshake Packetの主な構成要素は以下の通りです:

Type Name Description
int<1> protocol_version プロトコルバージョン(通常は10)
string<NUL> server_version MySQLのバージョン文字列
int<4> connection_id この接続に割り当てられた一意のID
string<8> auth_plugin_data_part_1 認証用の scramble(前半8バイト)
int<1> filler 固定値 0x00
int<2> capability_flags_1 サーバーの機能フラグ(下位2バイト)
int<1> character_set サーバーのデフォルト文字セット
int<2> status_flags サーバーのステータス
int<2> capability_flags_2 サーバーの機能フラグ(上位2バイト)
int<1> auth_plugin_data_len auth_plugin_dataの長さ
string<10> reserved 予約領域(0埋め)
string<$len> auth_plugin_data_part_2 認証用の scramble(後半部分)。ここで、$len=MAX(13, length of auth_plugin_data - 8) である。
string<NUL> auth_plugin_name 使用する認証プラグイン名(今回は mysql_native_password

ここで、string<_NUL_>はNULが出現するまでの文字列長を表します。

たくさんフィールドがありますが、Handshake Response Packet を作成するために必要なのはauth_plugin_data_part_1, auth_plugin_data_part_2, auth_plugin_name だけなので、あまり身構えなくても大丈夫です(後述)。

2. Handshake Response

クライアントは Initial Handshake Packet の情報を基に、Handshake Response Packet を作成してサーバーに送信します。このパケットには、ユーザー名、認証トークン、使用する認証プラグイン名などが含まれます。

Handshake Response Packetの主な構成要素:

Type Name Description
int<4> capability_flags クライアントがサポートする機能フラグ
int<4> max_packet_size クライアントが受信可能な最大パケットサイズ
int<1> character_set クライアントの文字セット
string<23> reserved 予約領域(0埋め)
string<NUL> username ユーザー名
int<1> auth_response_length 認証トークンの長さ
string<lenenc> auth_response 認証トークン
string<NUL> client_plugin_name 使用する認証プラグイン名

認証トークン(auth_response)は、パスワードとInitial Handshake Packet で受信した scramble を使って生成されます。mysql_native_password 認証プラグインの場合、以下のように生成します。

SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))

3. Authentication Response

Handshake Response Packet を受信したサーバーは、認証を試みます。認証が成功すれば OK Packet を、失敗すれば ERR Packet を返します。

  • OK Packet: 先頭バイトが 0x00 で、認証成功を表します
  • ERR Packet: 先頭バイトが 0xFF で、エラーコードとエラーメッセージが含まれます

OK Packet を受信したら、Connection Phase は完了し、Command Phase に移行します。

Command Phase

サーバーに対してコマンドを送信し、結果を取得するフェーズです。 クエリの実行は COM_QUERY command code(0x03)とその後にクエリ文字列を送るだけのシンプルなものです。 クエリの結果は可変長のため、上手く読み取りとパース処理する必要があります。

本記事ではサーバーとの接続までを実装するため、Command Phase についての詳細は省略します。気になる方はこちらのドキュメントをご覧ください。

Connection Phase の実装

それでは、実際にConnection Phase を実装していきます。

先のセクションでパケットの構造について確認しましたが、百聞は一見に如かずということで、Initial Handshake Packet を受信してパケットの中身を見てみます。

conn, err := net.Dial("tcp", "localhost:3306")
if err != nil {
    log.Fatal(err)
}
defer conn.Close()

header := make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
    log.Fatal(err)
}
fmt.Println("[header]")
fmt.Printf("%s\n", hex.Dump(header))

// Little endian
length := int(header[0]) | int(header[1])<<8 | int(header[2])<<16
fmt.Println("[length]", length)

payload := make([]byte, length)
if _, err := io.ReadFull(conn, payload); err != nil {
    log.Fatal(err)
}
fmt.Println("[payload]")
fmt.Printf("%s\n", hex.Dump(payload))

実行すると以下のような結果が得られます。

[header]
00000000  4a 00 00 00                                       |J...|

[length] 74
[payload]
00000000  0a 38 2e 30 2e 34 33 00  0b 00 00 00 36 1d 79 54  |.8.0.43.....6.yT|
00000010  04 6a 74 3a 00 ff ff ff  02 00 ff df 15 00 00 00  |.jt:............|
00000020  00 00 00 00 00 00 00 2b  5b 1b 66 25 42 1c 61 41  |.......+[.f%B.aA|
00000030  46 50 72 00 6d 79 73 71  6c 5f 6e 61 74 69 76 65  |FPr.mysql_native|
00000040  5f 70 61 73 73 77 6f 72  64 00                    |_password.|

Payloadの先頭からプロトコルバージョンが10、MySQLのバージョンは8.0.43、最後に認証プラグイン名など、先程の説明で出てきたものがちらほらと見えています。ドキュメントと実態が対応していそうな雰囲気を感じられて嬉しいですね。これらをパースしていきましょう。

Initial Handshake Packetのパース

受信したpayloadから必要な情報を抽出していきます。パケットの構造に従って、バイト列を順番に読み取っていきます。

pos := 0

// protocol_version (1 byte)
protocolVersion := payload[pos]
pos += 1
fmt.Println("[protocolVersion]", protocolVersion)

// server_version (NUL-terminated string)
idx := bytes.IndexByte(payload[pos:], 0x00)
if idx == -1 {
    log.Fatal("NUL not found")
}
serverVersion := payload[pos : pos+idx]
fmt.Printf("[server_version] %s\n", serverVersion)
pos += idx + 1

// connection_id (4 bytes)
connectionID := binary.LittleEndian.Uint32(payload[pos : pos+4])
fmt.Println("[connectionID]", connectionID)
pos += 4

// auth_plugin_data_part_1 (8 bytes)
authPluginDataPart1 := payload[pos : pos+8]
pos += 8

// filler (1 byte)
pos += 1

// capability_flags (2 + 2 bytes)
capabilityFlags1 := payload[pos : pos+2]
pos += 2

// character_set (1 byte)
charset := payload[pos]
pos += 1

// status_flags (2 bytes)
statusFlags := binary.LittleEndian.Uint16(payload[pos : pos+2])
pos += 2

// capability_flags_2 (2 bytes)
capabilityFlags2 := payload[pos : pos+2]
pos += 2

// auth_plugin_data_len (1 byte)
authPluginDataLen := int(payload[pos])
pos += 1

// reserved (10 bytes)
pos += 10

// auth_plugin_data_part_2
authPluginDataPart2Len := max(13, authPluginDataLen-8)
authPluginDataPart2 := payload[pos : pos+authPluginDataPart2Len]
pos += authPluginDataPart2Len

// auth_plugin_name (NUL-terminated string)
idx2 := bytes.IndexByte(payload[pos:], 0x00)
if idx2 == -1 {
    log.Fatal("NUL not found")
}
authPluginName := payload[pos : pos+idx2]
fmt.Printf("[authPluginName] %s\n", authPluginName)

ポイントは以下の通りです:

  • NUL終端文字列: string<NUL>はNULが出現するまでの文字列を取得するため、bytes.IndexByte0x00の位置を探す
  • Little Endian: binary.LittleEndianを使って多バイト整数を読み取る

実行結果:

[protocolVersion] 10
[server_version] 8.0.43
[connectionID] 11
[authPluginName] mysql_native_password

Handshake Response Packetの作成

次に、サーバーに送信するHandshake Response Packetを作成します。まず、認証トークンを生成します。

認証トークンの生成

改めて、mysql_native_password認証での認証トークンの生成は以下のようになります。

// SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))

func sha1Hash(data []byte) []byte {
    hash := sha1.Sum(data)
    return hash[:]
}

// scramble = auth_plugin_data_part_1 + auth_plugin_data_part_2[0:12]
scramble := append(authPluginDataPart1, authPluginDataPart2[0:12]...)

// SHA1(password)
authResponse := sha1Hash([]byte(PASSWORD))

// SHA1(scramble + SHA1(SHA1(password)))
tmp := sha1Hash(append(scramble, sha1Hash(authResponse)...))

// SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
for i := range 20 {
    authResponse[i] ^= tmp[i]
}

このXOR演算により、パスワードをそのまま送信することなく、サーバー側で検証可能な認証トークンが生成されます。

パケットの構築

次に、Handshake Response Packetのバイト列を構築します:

// パケットサイズを計算
func calculateBufferSize(authResponse []byte, authPluginName string) int {
    size := 4                        // header
    size += 4                        // client_flag
    size += 4                        // max_packet_size
    size += 1                        // character_set
    size += 23                       // filler
    size += len(USERNAME) + 1        // username + NUL
    size++                           // auth_response_length
    size += len(authResponse)        // auth_response
    size += len(authPluginName) + 1  // client_plugin_name + NUL
    return size
}

bufsize := calculateBufferSize(authResponse, string(authPluginName))
buf := make([]byte, bufsize)

pos = 0
payloadsize := bufsize - 4 // ヘッダーサイズを引いておく

// Header: payload_length (3 bytes) + sequence_id (1 byte)
buf[0] = byte(payloadsize)
buf[1] = byte(payloadsize >> 8)
buf[2] = byte(payloadsize >> 16)
buf[3] = 0x01  // sequence_id
pos += 4

// capability_flags (4 bytes)
const (
    CLIENT_PROTOCOL_41       = 0x00000200
    CLIENT_SECURE_CONNECTION = 0x00008000
)

// CLIENT_SECURE_CONNECTION は auth_response_length を使用する場合に必要
capabilityFlags := uint32(CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION)
buf[pos] = byte(capabilityFlags)
buf[pos+1] = byte(capabilityFlags >> 8)
buf[pos+2] = byte(capabilityFlags >> 16)
buf[pos+3] = byte(capabilityFlags >> 24)
pos += 4

// max_packet_size (4 bytes) - ゼロのままスキップ
pos += 4

// character_set (1 byte) - 0x33 = utf8_general_ci
buf[pos] = 0x33
pos++

// reserved (23 bytes) - ゼロのままスキップ
pos += 23

// username (NUL-terminated string)
copy(buf[pos:], USERNAME)
pos += len(USERNAME) + 1

// auth_response_length (1 byte)
buf[pos] = byte(len(authResponse))
pos++

// auth_response (20 bytes for SHA1)
copy(buf[pos:], authResponse)
pos += len(authResponse)

// client_plugin_name (NUL-terminated string)
copy(buf[pos:], authPluginName)

Handshake Responseの送信と認証結果の確認

作成したパケットをサーバーに送信し、認証結果を受け取ります。

// Handshake Response を送信
fmt.Printf("[write handshake response packet:]: \n%s", hex.Dump(buf))
if _, err := conn.Write(buf); err != nil {
    log.Fatal(err)
}

// サーバーからの応答を受信
header = make([]byte, 4)
if _, err := io.ReadFull(conn, header); err != nil {
    log.Fatal(err)
}

payloadSize := int(uint32(header[0]) | uint32(header[1])>>8 | uint32(header[2])>>16)
payload = make([]byte, payloadSize)
if _, err := io.ReadFull(conn, payload); err != nil {
    log.Fatal(err)
}

fmt.Println("[payload]")
fmt.Printf("%s\n", hex.Dump(payload))

if payload[0] == 0x00 {
    fmt.Println("\n✅ Successfully authenticated to MySQL server")
} else if payload[0] == 0xFF {
    fmt.Printf("\n❌ Authentication failed: %s\n", string(payload[3:]))
}

実行結果:

[write packets: 63 bytes]: 
00000000  3b 00 00 01 08 02 00 00  00 00 00 01 33 00 00 00  |;...........3...|
00000010  00 00 00 00 00 00 00 00  00 00 00 00 00 00 00 00  |................|
00000020  00 00 00 75 73 65 72 00  14 d8 5e 4a 2d 9a 7f 8e  |...user...^J-...|
00000030  16 7c 42 ea 4b 9f 3d 5a  1f 23 d9 6f 4c 6d 79 73  |.|B.K.=Z.#.oLmys|
00000040  71 6c 5f 6e 61 74 69 76  65 5f 70 61 73 73 77 6f  |ql_native_passwo|
00000050  72 64 00                                          |rd.|

[header]
00000000  07 00 00 02                                       |....|

[payload]
00000000  00 00 00 02 00 00 00                              |.......|

✅ Successfully authenticated to MySQL server

先頭バイトが0x00のOK Packetが返ってくれば認証に成功したことになります。

最終コード

package main

import (
    "bytes"
    "crypto/sha1"
    "encoding/binary"
    "encoding/hex"
    "fmt"
    "io"
    "log"
    "net"
)

const (
    USERNAME           = "user"
    PASSWORD           = "password"
    NUL                = 0x00
    CLIENT_PLUGIN_AUTH = 0x00080000
    CLIENT_PROTOCOL_41 = 0x00000200
)

func main() {
    conn, err := net.Dial("tcp", "localhost:3306")
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    // ========================================
    // Initial Handshake Packet を受信
    // ========================================
    header := make([]byte, 4)
    if _, err := io.ReadFull(conn, header); err != nil {
        log.Fatal(err)
    }
    fmt.Println("[header]")
    fmt.Printf("%s\n", hex.Dump(header))

    // Little endian でペイロード長を取得
    length := int(header[0]) | int(header[1])<<8 | int(header[2])<<16
    fmt.Println("[length]", length)

    payload := make([]byte, length)
    if _, err := io.ReadFull(conn, payload); err != nil {
        log.Fatal(err)
    }
    fmt.Println("[payload]")
    fmt.Printf("%s\n", hex.Dump(payload))

    // ========================================
    // Initial Handshake Packet をパース
    // ========================================
    pos := 0

    // protocol_version (1 byte)
    protocolVersion := payload[pos]
    pos += 1
    fmt.Println("[protocolVersion]", protocolVersion)

    // server_version (NUL-terminated string)
    idx := bytes.IndexByte(payload[pos:], NUL)
    if idx == -1 {
        log.Fatal("NUL not found")
    }
    serverVersion := payload[pos : pos+idx]
    fmt.Printf("[server_version] %s\n", serverVersion)
    pos += idx + 1

    // connection_id (4 bytes)
    connectionID := binary.LittleEndian.Uint32(payload[pos : pos+4])
    fmt.Println("[connectionID]", connectionID)
    pos += 4

    // auth_plugin_data_part_1 (8 bytes)
    authPluginDataPart1 := payload[pos : pos+8]
    fmt.Println("[authPluginDataPart1]", authPluginDataPart1)
    fmt.Printf("    %s", hex.Dump(authPluginDataPart1))
    pos += 8

    // filler (1 byte)
    pos += 1

    // capability_flags_1 (2 bytes)
    capabilityFlags1 := payload[pos : pos+2]
    fmt.Println("[capabilityFlags1]", capabilityFlags1)
    pos += 2

    // character_set (1 byte)
    charset := payload[pos]
    fmt.Println("[charset]", charset)
    pos += 1

    // status_flags (2 bytes)
    statusFlags := binary.LittleEndian.Uint16(payload[pos : pos+2])
    fmt.Println("[statusFlags]", statusFlags)
    pos += 2

    // capability_flags_2 (2 bytes)
    capabilityFlags2 := payload[pos : pos+2]
    fmt.Println("[capabilityFlags2]", capabilityFlags2)
    pos += 2

    // capabilities を計算
    capabilities := binary.LittleEndian.Uint32(append(capabilityFlags1, capabilityFlags2...))
    fmt.Println("[capabilities]", capabilities)

    // CLIENT_PLUGIN_AUTH のサポートチェック
    if capabilities&CLIENT_PLUGIN_AUTH == 0 {
        log.Fatal("ClientPluginAuth is not supported")
    }

    // auth_plugin_data_len (1 byte)
    authPluginDataLen := int(payload[pos])
    fmt.Println("[authPluginDataLen]", authPluginDataLen)
    pos += 1

    // reserved (10 bytes)
    pos += 10

    // auth_plugin_data_part_2
    authPluginDataPart2Len := max(13, authPluginDataLen-8)
    authPluginDataPart2 := payload[pos : pos+authPluginDataPart2Len]
    fmt.Println("[authPluginDataPart2]", authPluginDataPart2)
    fmt.Printf("    %s", hex.Dump(authPluginDataPart2))
    pos += authPluginDataPart2Len

    // auth_plugin_name (NUL-terminated string)
    idx2 := bytes.IndexByte(payload[pos:], NUL)
    if idx2 == -1 {
        log.Fatal("NUL not found")
    }
    authPluginName := payload[pos : pos+idx2]
    fmt.Println("[authPluginName]", authPluginName)
    fmt.Printf("    %s\n", authPluginName)
    pos += idx2 + 1

    // ========================================
    // Handshake Response Packet を作成
    // ========================================

    // 認証トークンを生成
    // SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
    authResponse := sha1Hash([]byte(PASSWORD))
    scramble := append(authPluginDataPart1, authPluginDataPart2[0:12]...)
    tmp := sha1Hash(append(scramble, sha1Hash(authResponse)...))
    for i := range 20 {
        authResponse[i] ^= tmp[i]
    }

    // パケットバッファを作成
    bufsize := calculateBufferSize(authResponse, string(authPluginName))
    buf := make([]byte, bufsize)

    pos = 0
    payloadsize := bufsize - 4

    // Header: payload_length (3 bytes) + sequence_id (1 byte)
    buf[0] = byte(payloadsize)
    buf[1] = byte(payloadsize >> 8)
    buf[2] = byte(payloadsize >> 16)
    buf[3] = 0x01 // sequence_id
    pos += 4

    // capability_flags (4 bytes)
    // CLIENT_PLUGIN_AUTH または CLIENT_SECURE_CONNECTION は
    // auth_response_length を使用する場合に必要
    capabilityFlags := uint32(CLIENT_PROTOCOL_41)
    buf[pos] = byte(capabilityFlags)
    buf[pos+1] = byte(capabilityFlags >> 8)
    buf[pos+2] = byte(capabilityFlags >> 16)
    buf[pos+3] = byte(capabilityFlags >> 24)
    pos += 4

    // max_packet_size (4 bytes) - ゼロのまま
    pos += 4

    // character_set (1 byte) - 0x33 = utf8_general_ci
    buf[pos] = 0x33
    pos++

    // reserved (23 bytes) - ゼロのまま
    pos += 23

    // username (NUL-terminated string)
    copy(buf[pos:], USERNAME)
    pos += len(USERNAME) + 1

    // auth_response_length (1 byte)
    buf[pos] = byte(len(authResponse))
    pos++

    // auth_response
    copy(buf[pos:], authResponse)
    pos += len(authResponse)

    // client_plugin_name (NUL-terminated string)
    copy(buf[pos:], authPluginName)

    // ========================================
    // Handshake Response を送信
    // ========================================
    fmt.Printf("\n[write handshake response packet]: \n%s\n", hex.Dump(buf))
    if _, err := conn.Write(buf); err != nil {
        log.Fatal(err)
    }

    // ========================================
    // 認証結果を受信
    // ========================================
    header = make([]byte, 4)
    if _, err := io.ReadFull(conn, header); err != nil {
        log.Fatal(err)
    }
    fmt.Println("[header]")
    fmt.Printf("%s\n", hex.Dump(header))

    payloadSize := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
    payload = make([]byte, payloadSize)
    if _, err := io.ReadFull(conn, payload); err != nil {
        log.Fatal(err)
    }
    fmt.Println("[payload]")
    fmt.Printf("%s\n", hex.Dump(payload))

    // 認証結果を判定
    switch payload[0] {
    case 0x00:
        fmt.Println("\n✅ Successfully authenticated to MySQL server")
    case 0xFF:
        fmt.Printf("\n❌ Authentication failed: %s\n", string(payload[3:]))
    }
}

// SHA1ハッシュを計算
func sha1Hash(data []byte) []byte {
    hash := sha1.Sum(data)
    return hash[:]
}

// Handshake Response Packetのサイズを計算
func calculateBufferSize(authResponse []byte, authPluginName string) int {
    size := 4                       // header
    size += 4                       // client_flag
    size += 4                       // max_packet_size
    size += 1                       // character_set
    size += 23                      // filler
    size += len(USERNAME) + 1       // username + NUL
    size++                          // auth_response_length
    size += len(authResponse)       // auth_response
    size += len(authPluginName) + 1 // client_plugin_name + NUL
    return size
}

おわりに

今回は MySQL プロトコルの Connection Phase を実装することで、サーバーとクライアント間でどのような情報がやり取りされているかを理解できました。 今回の実装を踏まえて、MySQLドライバ(go-sql-driver/mysql )の実装を眺めてみましたが、上手いコードの書き方や工夫が見られて学びになりそうでした。 また、バイナリを便利に扱うために hex, bytes パッケージを色々と調べて知らない便利メソッドを見つけて楽しめました。次回は Command Phase の実装や、caching_sha2_password プラグインにも挑戦したいと思います。

参考資料

We Are Hiring !!

UPSIDERでは現在積極採用をしています。 ぜひお気軽にご応募ください。

herp.careers

herp.careers

UPSIDER Engineering Deckはこちら📣

speakerdeck.com


  1. caching_sha2_password ではcache hitの有無やTLS接続を考慮する必要があり、今回は簡単のためにmysql_native_passwordを使用します。
  2. ここで紹介するパケット構造は今回の検証で必要なフィールドのみに絞っています。MySQLサーバーの設定によっては、capability_flagsなどの内容が異なる場合があります。