カード事業部でバックエンドエンジニアをしている Mimura です。弊社ではマイクロサービスごとに要件や特性に合ったデータベースが選定されますが、私の関わるサービスではMySQLが選定されることが多いです。
普段は database/sql や ORMライブラリを通じて当たり前のようにMySQLとやり取りしていますが、その裏側でどのような通信が流れているのかは意識する機会が多くありません。そこで今回は、MySQLのClient/Server Protocolを自前で実装し、TCPソケットでMySQLサーバーに直接接続してハンドシェイクと認証を行うクライアントをGoで書いてみました。
本記事ではMySQLサーバーに対して接続を確立し、ハンドシェイクと認証を行うクライアントを実装してみます。実装を通じて以下のような学びが得られると思います。
- バイナリプロトコルの読み書き
- チャレンジレスポンス認証の仕組み(mysql_native_password)
- パケットベースの通信プロトコルの理解
準備
バックグラウンドでMySQLコンテナを起動しておきます。
docker run -d \ --name mysql \ -p 3306:3306 \ -e MYSQL_ROOT_PASSWORD=password \ -e MYSQL_USER=user \ -e MYSQL_PASSWORD=password \ mysql:8.0.43 \ --default-authentication-plugin=mysql_native_password
本記事執筆時点での最新LTSはMySQL 8.4ですが、簡単に mysql_native_password プラグインで認証にするために今回はMySQL 8.0.43を使用します1。以降では、このMySQLコンテナに対してクライアントを実装し、通信を試みます。
MySQLプロトコルの基礎知識
MySQLがクライアントとサーバー間の通信を行う際には、その名の通りのMySQLプロトコルが使用されます。MySQLプロトコルはTCPの上でパケットという単位でデータがやり取りされます。パケットは以下のような構造をしています。
| Type | Name | Description |
|---|---|---|
| int<3> | payload_length | ペイロードの長さ |
| int<1> | sequence_id | シーケンス番号 |
| string<var> | payload | 実際のデータ |
ここで、int<num> は num バイトの整数を表し、string<var> は可変長の文字列を表します。
それぞれ役割は名前の通りで、先頭の4バイトがヘッダーとして扱われます。すなわち、パケットのやり取りでは先頭4バイトのヘッダーを読んでpayload_lengthを取得し、その後に続くデータをpayloadとして読み取ることになりそうだなと分かります。
例えば、01 00 00 00 01 というパケットは、payload_lengthが1、sequence_idが0、payloadが 01 であることを表します。(payload_length は Little endianで表されていることに注意)
基本となるデータ構造を理解したところで、MySQLプロトコルの具体的な通信を見ていきます。MySQLの通信は Connection Phase と Command Phase の2つのフェーズからなります。
Connection Phase
Connection Phaseは、クライアントとサーバーが接続を確立し認証を行うフェーズです。このフェーズは以下の3つのステップで構成されます。2
1. Initial Handshake
クライアントがサーバーに TCP接続を確立すると、サーバーは最初に Initial Handshake Packet を送信してきます。このパケットには、クライアントが認証を行うために必要な情報が含まれています。
Initial Handshake Packetの主な構成要素は以下の通りです:
| Type | Name | Description |
|---|---|---|
| int<1> | protocol_version | プロトコルバージョン(通常は10) |
| string<NUL> | server_version | MySQLのバージョン文字列 |
| int<4> | connection_id | この接続に割り当てられた一意のID |
| string<8> | auth_plugin_data_part_1 | 認証用の scramble(前半8バイト) |
| int<1> | filler | 固定値 0x00 |
| int<2> | capability_flags_1 | サーバーの機能フラグ(下位2バイト) |
| int<1> | character_set | サーバーのデフォルト文字セット |
| int<2> | status_flags | サーバーのステータス |
| int<2> | capability_flags_2 | サーバーの機能フラグ(上位2バイト) |
| int<1> | auth_plugin_data_len | auth_plugin_dataの長さ |
| string<10> | reserved | 予約領域(0埋め) |
| string<$len> | auth_plugin_data_part_2 | 認証用の scramble(後半部分)。ここで、$len=MAX(13, length of auth_plugin_data - 8) である。 |
| string<NUL> | auth_plugin_name | 使用する認証プラグイン名(今回は mysql_native_password ) |
ここで、string<_NUL_>はNULが出現するまでの文字列長を表します。
たくさんフィールドがありますが、Handshake Response Packet を作成するために必要なのはauth_plugin_data_part_1, auth_plugin_data_part_2, auth_plugin_name だけなので、あまり身構えなくても大丈夫です(後述)。
2. Handshake Response
クライアントは Initial Handshake Packet の情報を基に、Handshake Response Packet を作成してサーバーに送信します。このパケットには、ユーザー名、認証トークン、使用する認証プラグイン名などが含まれます。
Handshake Response Packetの主な構成要素:
| Type | Name | Description |
|---|---|---|
| int<4> | capability_flags | クライアントがサポートする機能フラグ |
| int<4> | max_packet_size | クライアントが受信可能な最大パケットサイズ |
| int<1> | character_set | クライアントの文字セット |
| string<23> | reserved | 予約領域(0埋め) |
| string<NUL> | username | ユーザー名 |
| int<1> | auth_response_length | 認証トークンの長さ |
| string<lenenc> | auth_response | 認証トークン |
| string<NUL> | client_plugin_name | 使用する認証プラグイン名 |
認証トークン(auth_response)は、パスワードとInitial Handshake Packet で受信した scramble を使って生成されます。mysql_native_password 認証プラグインの場合、以下のように生成します。
SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
3. Authentication Response
Handshake Response Packet を受信したサーバーは、認証を試みます。認証が成功すれば OK Packet を、失敗すれば ERR Packet を返します。
- OK Packet: 先頭バイトが
0x00で、認証成功を表します - ERR Packet: 先頭バイトが
0xFFで、エラーコードとエラーメッセージが含まれます
OK Packet を受信したら、Connection Phase は完了し、Command Phase に移行します。
Command Phase
サーバーに対してコマンドを送信し、結果を取得するフェーズです。
クエリの実行は COM_QUERY command code(0x03)とその後にクエリ文字列を送るだけのシンプルなものです。
クエリの結果は可変長のため、上手く読み取りとパース処理する必要があります。
本記事ではサーバーとの接続までを実装するため、Command Phase についての詳細は省略します。気になる方はこちらのドキュメントをご覧ください。
Connection Phase の実装
それでは、実際にConnection Phase を実装していきます。
先のセクションでパケットの構造について確認しましたが、百聞は一見に如かずということで、Initial Handshake Packet を受信してパケットの中身を見てみます。
conn, err := net.Dial("tcp", "localhost:3306") if err != nil { log.Fatal(err) } defer conn.Close() header := make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { log.Fatal(err) } fmt.Println("[header]") fmt.Printf("%s\n", hex.Dump(header)) // Little endian length := int(header[0]) | int(header[1])<<8 | int(header[2])<<16 fmt.Println("[length]", length) payload := make([]byte, length) if _, err := io.ReadFull(conn, payload); err != nil { log.Fatal(err) } fmt.Println("[payload]") fmt.Printf("%s\n", hex.Dump(payload))
実行すると以下のような結果が得られます。
[header] 00000000 4a 00 00 00 |J...| [length] 74 [payload] 00000000 0a 38 2e 30 2e 34 33 00 0b 00 00 00 36 1d 79 54 |.8.0.43.....6.yT| 00000010 04 6a 74 3a 00 ff ff ff 02 00 ff df 15 00 00 00 |.jt:............| 00000020 00 00 00 00 00 00 00 2b 5b 1b 66 25 42 1c 61 41 |.......+[.f%B.aA| 00000030 46 50 72 00 6d 79 73 71 6c 5f 6e 61 74 69 76 65 |FPr.mysql_native| 00000040 5f 70 61 73 73 77 6f 72 64 00 |_password.|
Payloadの先頭からプロトコルバージョンが10、MySQLのバージョンは8.0.43、最後に認証プラグイン名など、先程の説明で出てきたものがちらほらと見えています。ドキュメントと実態が対応していそうな雰囲気を感じられて嬉しいですね。これらをパースしていきましょう。
Initial Handshake Packetのパース
受信したpayloadから必要な情報を抽出していきます。パケットの構造に従って、バイト列を順番に読み取っていきます。
pos := 0 // protocol_version (1 byte) protocolVersion := payload[pos] pos += 1 fmt.Println("[protocolVersion]", protocolVersion) // server_version (NUL-terminated string) idx := bytes.IndexByte(payload[pos:], 0x00) if idx == -1 { log.Fatal("NUL not found") } serverVersion := payload[pos : pos+idx] fmt.Printf("[server_version] %s\n", serverVersion) pos += idx + 1 // connection_id (4 bytes) connectionID := binary.LittleEndian.Uint32(payload[pos : pos+4]) fmt.Println("[connectionID]", connectionID) pos += 4 // auth_plugin_data_part_1 (8 bytes) authPluginDataPart1 := payload[pos : pos+8] pos += 8 // filler (1 byte) pos += 1 // capability_flags (2 + 2 bytes) capabilityFlags1 := payload[pos : pos+2] pos += 2 // character_set (1 byte) charset := payload[pos] pos += 1 // status_flags (2 bytes) statusFlags := binary.LittleEndian.Uint16(payload[pos : pos+2]) pos += 2 // capability_flags_2 (2 bytes) capabilityFlags2 := payload[pos : pos+2] pos += 2 // auth_plugin_data_len (1 byte) authPluginDataLen := int(payload[pos]) pos += 1 // reserved (10 bytes) pos += 10 // auth_plugin_data_part_2 authPluginDataPart2Len := max(13, authPluginDataLen-8) authPluginDataPart2 := payload[pos : pos+authPluginDataPart2Len] pos += authPluginDataPart2Len // auth_plugin_name (NUL-terminated string) idx2 := bytes.IndexByte(payload[pos:], 0x00) if idx2 == -1 { log.Fatal("NUL not found") } authPluginName := payload[pos : pos+idx2] fmt.Printf("[authPluginName] %s\n", authPluginName)
ポイントは以下の通りです:
- NUL終端文字列: string<NUL>はNULが出現するまでの文字列を取得するため、
bytes.IndexByteで0x00の位置を探す - Little Endian:
binary.LittleEndianを使って多バイト整数を読み取る
実行結果:
[protocolVersion] 10 [server_version] 8.0.43 [connectionID] 11 [authPluginName] mysql_native_password
Handshake Response Packetの作成
次に、サーバーに送信するHandshake Response Packetを作成します。まず、認証トークンを生成します。
認証トークンの生成
改めて、mysql_native_password認証での認証トークンの生成は以下のようになります。
// SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password))) func sha1Hash(data []byte) []byte { hash := sha1.Sum(data) return hash[:] } // scramble = auth_plugin_data_part_1 + auth_plugin_data_part_2[0:12] scramble := append(authPluginDataPart1, authPluginDataPart2[0:12]...) // SHA1(password) authResponse := sha1Hash([]byte(PASSWORD)) // SHA1(scramble + SHA1(SHA1(password))) tmp := sha1Hash(append(scramble, sha1Hash(authResponse)...)) // SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password))) for i := range 20 { authResponse[i] ^= tmp[i] }
このXOR演算により、パスワードをそのまま送信することなく、サーバー側で検証可能な認証トークンが生成されます。
パケットの構築
次に、Handshake Response Packetのバイト列を構築します:
// パケットサイズを計算 func calculateBufferSize(authResponse []byte, authPluginName string) int { size := 4 // header size += 4 // client_flag size += 4 // max_packet_size size += 1 // character_set size += 23 // filler size += len(USERNAME) + 1 // username + NUL size++ // auth_response_length size += len(authResponse) // auth_response size += len(authPluginName) + 1 // client_plugin_name + NUL return size } bufsize := calculateBufferSize(authResponse, string(authPluginName)) buf := make([]byte, bufsize) pos = 0 payloadsize := bufsize - 4 // ヘッダーサイズを引いておく // Header: payload_length (3 bytes) + sequence_id (1 byte) buf[0] = byte(payloadsize) buf[1] = byte(payloadsize >> 8) buf[2] = byte(payloadsize >> 16) buf[3] = 0x01 // sequence_id pos += 4 // capability_flags (4 bytes) const ( CLIENT_PROTOCOL_41 = 0x00000200 CLIENT_SECURE_CONNECTION = 0x00008000 ) // CLIENT_SECURE_CONNECTION は auth_response_length を使用する場合に必要 capabilityFlags := uint32(CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION) buf[pos] = byte(capabilityFlags) buf[pos+1] = byte(capabilityFlags >> 8) buf[pos+2] = byte(capabilityFlags >> 16) buf[pos+3] = byte(capabilityFlags >> 24) pos += 4 // max_packet_size (4 bytes) - ゼロのままスキップ pos += 4 // character_set (1 byte) - 0x33 = utf8_general_ci buf[pos] = 0x33 pos++ // reserved (23 bytes) - ゼロのままスキップ pos += 23 // username (NUL-terminated string) copy(buf[pos:], USERNAME) pos += len(USERNAME) + 1 // auth_response_length (1 byte) buf[pos] = byte(len(authResponse)) pos++ // auth_response (20 bytes for SHA1) copy(buf[pos:], authResponse) pos += len(authResponse) // client_plugin_name (NUL-terminated string) copy(buf[pos:], authPluginName)
Handshake Responseの送信と認証結果の確認
作成したパケットをサーバーに送信し、認証結果を受け取ります。
// Handshake Response を送信 fmt.Printf("[write handshake response packet:]: \n%s", hex.Dump(buf)) if _, err := conn.Write(buf); err != nil { log.Fatal(err) } // サーバーからの応答を受信 header = make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { log.Fatal(err) } payloadSize := int(uint32(header[0]) | uint32(header[1])>>8 | uint32(header[2])>>16) payload = make([]byte, payloadSize) if _, err := io.ReadFull(conn, payload); err != nil { log.Fatal(err) } fmt.Println("[payload]") fmt.Printf("%s\n", hex.Dump(payload)) if payload[0] == 0x00 { fmt.Println("\n✅ Successfully authenticated to MySQL server") } else if payload[0] == 0xFF { fmt.Printf("\n❌ Authentication failed: %s\n", string(payload[3:])) }
実行結果:
[write packets: 63 bytes]: 00000000 3b 00 00 01 08 02 00 00 00 00 00 01 33 00 00 00 |;...........3...| 00000010 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 |................| 00000020 00 00 00 75 73 65 72 00 14 d8 5e 4a 2d 9a 7f 8e |...user...^J-...| 00000030 16 7c 42 ea 4b 9f 3d 5a 1f 23 d9 6f 4c 6d 79 73 |.|B.K.=Z.#.oLmys| 00000040 71 6c 5f 6e 61 74 69 76 65 5f 70 61 73 73 77 6f |ql_native_passwo| 00000050 72 64 00 |rd.| [header] 00000000 07 00 00 02 |....| [payload] 00000000 00 00 00 02 00 00 00 |.......| ✅ Successfully authenticated to MySQL server
先頭バイトが0x00のOK Packetが返ってくれば認証に成功したことになります。
最終コード
package main import ( "bytes" "crypto/sha1" "encoding/binary" "encoding/hex" "fmt" "io" "log" "net" ) const ( USERNAME = "user" PASSWORD = "password" NUL = 0x00 CLIENT_PLUGIN_AUTH = 0x00080000 CLIENT_PROTOCOL_41 = 0x00000200 ) func main() { conn, err := net.Dial("tcp", "localhost:3306") if err != nil { log.Fatal(err) } defer conn.Close() // ======================================== // Initial Handshake Packet を受信 // ======================================== header := make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { log.Fatal(err) } fmt.Println("[header]") fmt.Printf("%s\n", hex.Dump(header)) // Little endian でペイロード長を取得 length := int(header[0]) | int(header[1])<<8 | int(header[2])<<16 fmt.Println("[length]", length) payload := make([]byte, length) if _, err := io.ReadFull(conn, payload); err != nil { log.Fatal(err) } fmt.Println("[payload]") fmt.Printf("%s\n", hex.Dump(payload)) // ======================================== // Initial Handshake Packet をパース // ======================================== pos := 0 // protocol_version (1 byte) protocolVersion := payload[pos] pos += 1 fmt.Println("[protocolVersion]", protocolVersion) // server_version (NUL-terminated string) idx := bytes.IndexByte(payload[pos:], NUL) if idx == -1 { log.Fatal("NUL not found") } serverVersion := payload[pos : pos+idx] fmt.Printf("[server_version] %s\n", serverVersion) pos += idx + 1 // connection_id (4 bytes) connectionID := binary.LittleEndian.Uint32(payload[pos : pos+4]) fmt.Println("[connectionID]", connectionID) pos += 4 // auth_plugin_data_part_1 (8 bytes) authPluginDataPart1 := payload[pos : pos+8] fmt.Println("[authPluginDataPart1]", authPluginDataPart1) fmt.Printf(" %s", hex.Dump(authPluginDataPart1)) pos += 8 // filler (1 byte) pos += 1 // capability_flags_1 (2 bytes) capabilityFlags1 := payload[pos : pos+2] fmt.Println("[capabilityFlags1]", capabilityFlags1) pos += 2 // character_set (1 byte) charset := payload[pos] fmt.Println("[charset]", charset) pos += 1 // status_flags (2 bytes) statusFlags := binary.LittleEndian.Uint16(payload[pos : pos+2]) fmt.Println("[statusFlags]", statusFlags) pos += 2 // capability_flags_2 (2 bytes) capabilityFlags2 := payload[pos : pos+2] fmt.Println("[capabilityFlags2]", capabilityFlags2) pos += 2 // capabilities を計算 capabilities := binary.LittleEndian.Uint32(append(capabilityFlags1, capabilityFlags2...)) fmt.Println("[capabilities]", capabilities) // CLIENT_PLUGIN_AUTH のサポートチェック if capabilities&CLIENT_PLUGIN_AUTH == 0 { log.Fatal("ClientPluginAuth is not supported") } // auth_plugin_data_len (1 byte) authPluginDataLen := int(payload[pos]) fmt.Println("[authPluginDataLen]", authPluginDataLen) pos += 1 // reserved (10 bytes) pos += 10 // auth_plugin_data_part_2 authPluginDataPart2Len := max(13, authPluginDataLen-8) authPluginDataPart2 := payload[pos : pos+authPluginDataPart2Len] fmt.Println("[authPluginDataPart2]", authPluginDataPart2) fmt.Printf(" %s", hex.Dump(authPluginDataPart2)) pos += authPluginDataPart2Len // auth_plugin_name (NUL-terminated string) idx2 := bytes.IndexByte(payload[pos:], NUL) if idx2 == -1 { log.Fatal("NUL not found") } authPluginName := payload[pos : pos+idx2] fmt.Println("[authPluginName]", authPluginName) fmt.Printf(" %s\n", authPluginName) pos += idx2 + 1 // ======================================== // Handshake Response Packet を作成 // ======================================== // 認証トークンを生成 // SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password))) authResponse := sha1Hash([]byte(PASSWORD)) scramble := append(authPluginDataPart1, authPluginDataPart2[0:12]...) tmp := sha1Hash(append(scramble, sha1Hash(authResponse)...)) for i := range 20 { authResponse[i] ^= tmp[i] } // パケットバッファを作成 bufsize := calculateBufferSize(authResponse, string(authPluginName)) buf := make([]byte, bufsize) pos = 0 payloadsize := bufsize - 4 // Header: payload_length (3 bytes) + sequence_id (1 byte) buf[0] = byte(payloadsize) buf[1] = byte(payloadsize >> 8) buf[2] = byte(payloadsize >> 16) buf[3] = 0x01 // sequence_id pos += 4 // capability_flags (4 bytes) // CLIENT_PLUGIN_AUTH または CLIENT_SECURE_CONNECTION は // auth_response_length を使用する場合に必要 capabilityFlags := uint32(CLIENT_PROTOCOL_41) buf[pos] = byte(capabilityFlags) buf[pos+1] = byte(capabilityFlags >> 8) buf[pos+2] = byte(capabilityFlags >> 16) buf[pos+3] = byte(capabilityFlags >> 24) pos += 4 // max_packet_size (4 bytes) - ゼロのまま pos += 4 // character_set (1 byte) - 0x33 = utf8_general_ci buf[pos] = 0x33 pos++ // reserved (23 bytes) - ゼロのまま pos += 23 // username (NUL-terminated string) copy(buf[pos:], USERNAME) pos += len(USERNAME) + 1 // auth_response_length (1 byte) buf[pos] = byte(len(authResponse)) pos++ // auth_response copy(buf[pos:], authResponse) pos += len(authResponse) // client_plugin_name (NUL-terminated string) copy(buf[pos:], authPluginName) // ======================================== // Handshake Response を送信 // ======================================== fmt.Printf("\n[write handshake response packet]: \n%s\n", hex.Dump(buf)) if _, err := conn.Write(buf); err != nil { log.Fatal(err) } // ======================================== // 認証結果を受信 // ======================================== header = make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { log.Fatal(err) } fmt.Println("[header]") fmt.Printf("%s\n", hex.Dump(header)) payloadSize := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) payload = make([]byte, payloadSize) if _, err := io.ReadFull(conn, payload); err != nil { log.Fatal(err) } fmt.Println("[payload]") fmt.Printf("%s\n", hex.Dump(payload)) // 認証結果を判定 switch payload[0] { case 0x00: fmt.Println("\n✅ Successfully authenticated to MySQL server") case 0xFF: fmt.Printf("\n❌ Authentication failed: %s\n", string(payload[3:])) } } // SHA1ハッシュを計算 func sha1Hash(data []byte) []byte { hash := sha1.Sum(data) return hash[:] } // Handshake Response Packetのサイズを計算 func calculateBufferSize(authResponse []byte, authPluginName string) int { size := 4 // header size += 4 // client_flag size += 4 // max_packet_size size += 1 // character_set size += 23 // filler size += len(USERNAME) + 1 // username + NUL size++ // auth_response_length size += len(authResponse) // auth_response size += len(authPluginName) + 1 // client_plugin_name + NUL return size }
おわりに
今回は MySQL プロトコルの Connection Phase を実装することで、サーバーとクライアント間でどのような情報がやり取りされているかを理解できました。 今回の実装を踏まえて、MySQLドライバ(go-sql-driver/mysql )の実装を眺めてみましたが、上手いコードの書き方や工夫が見られて学びになりそうでした。 また、バイナリを便利に扱うために hex, bytes パッケージを色々と調べて知らない便利メソッドを見つけて楽しめました。次回は Command Phase の実装や、caching_sha2_password プラグインにも挑戦したいと思います。
参考資料
- https://gihyo.jp/dev/serial/01/mysql-road-construction-news/0078
- https://dev.mysql.com/doc/dev/mysql-server/8.0.43/page_protocol_connection_phase_authentication_methods_native_password_authentication.html
We Are Hiring !!
UPSIDERでは現在積極採用をしています。 ぜひお気軽にご応募ください。
UPSIDER Engineering Deckはこちら📣