@@ -294,7 +294,7 @@ | |||||
[[projects]] | [[projects]] | ||||
name = "github.com/go-sql-driver/mysql" | name = "github.com/go-sql-driver/mysql" | ||||
packages = ["."] | packages = ["."] | ||||
revision = "ce924a41eea897745442daaa1739089b0f3f561d" | |||||
revision = "d523deb1b23d913de5bdada721a6071e71283618" | |||||
[[projects]] | [[projects]] | ||||
name = "github.com/go-xorm/builder" | name = "github.com/go-xorm/builder" | ||||
@@ -873,6 +873,6 @@ | |||||
[solve-meta] | [solve-meta] | ||||
analyzer-name = "dep" | analyzer-name = "dep" | ||||
analyzer-version = 1 | analyzer-version = 1 | ||||
inputs-digest = "036b8c882671cf8d2c5e2fdbe53b1bdfbd39f7ebd7765bd50276c7c4ecf16687" | |||||
inputs-digest = "96c83a3502bd50c5ca8e4d9b4145172267630270e587c79b7253156725eeb9b8" | |||||
solver-name = "gps-cdcl" | solver-name = "gps-cdcl" | ||||
solver-version = 1 | solver-version = 1 |
@@ -41,6 +41,10 @@ ignored = ["google.golang.org/appengine*"] | |||||
revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03" | revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03" | ||||
[[override]] | [[override]] | ||||
name = "github.com/go-sql-driver/mysql" | |||||
revision = "d523deb1b23d913de5bdada721a6071e71283618" | |||||
[[override]] | |||||
name = "github.com/gorilla/mux" | name = "github.com/gorilla/mux" | ||||
revision = "757bef944d0f21880861c2dd9c871ca543023cba" | revision = "757bef944d0f21880861c2dd9c871ca543023cba" | ||||
@@ -12,34 +12,63 @@ | |||||
# Individual Persons | # Individual Persons | ||||
Aaron Hopkins <go-sql-driver at die.net> | Aaron Hopkins <go-sql-driver at die.net> | ||||
Achille Roussel <achille.roussel at gmail.com> | |||||
Alexey Palazhchenko <alexey.palazhchenko at gmail.com> | |||||
Andrew Reid <andrew.reid at tixtrack.com> | |||||
Arne Hormann <arnehormann at gmail.com> | Arne Hormann <arnehormann at gmail.com> | ||||
Asta Xie <xiemengjun at gmail.com> | |||||
Bulat Gaifullin <gaifullinbf at gmail.com> | |||||
Carlos Nieto <jose.carlos at menteslibres.net> | Carlos Nieto <jose.carlos at menteslibres.net> | ||||
Chris Moos <chris at tech9computers.com> | Chris Moos <chris at tech9computers.com> | ||||
Craig Wilson <craiggwilson at gmail.com> | |||||
Daniel Montoya <dsmontoyam at gmail.com> | |||||
Daniel Nichter <nil at codenode.com> | Daniel Nichter <nil at codenode.com> | ||||
Daniël van Eeden <git at myname.nl> | Daniël van Eeden <git at myname.nl> | ||||
Dave Protasowski <dprotaso at gmail.com> | |||||
DisposaBoy <disposaboy at dby.me> | DisposaBoy <disposaboy at dby.me> | ||||
Egor Smolyakov <egorsmkv at gmail.com> | |||||
Evan Shaw <evan at vendhq.com> | |||||
Frederick Mayle <frederickmayle at gmail.com> | Frederick Mayle <frederickmayle at gmail.com> | ||||
Gustavo Kristic <gkristic at gmail.com> | Gustavo Kristic <gkristic at gmail.com> | ||||
Hajime Nakagami <nakagami at gmail.com> | |||||
Hanno Braun <mail at hannobraun.com> | Hanno Braun <mail at hannobraun.com> | ||||
Henri Yandell <flamefew at gmail.com> | Henri Yandell <flamefew at gmail.com> | ||||
Hirotaka Yamamoto <ymmt2005 at gmail.com> | Hirotaka Yamamoto <ymmt2005 at gmail.com> | ||||
ICHINOSE Shogo <shogo82148 at gmail.com> | |||||
INADA Naoki <songofacandy at gmail.com> | INADA Naoki <songofacandy at gmail.com> | ||||
Jacek Szwec <szwec.jacek at gmail.com> | |||||
James Harr <james.harr at gmail.com> | James Harr <james.harr at gmail.com> | ||||
Jeff Hodges <jeff at somethingsimilar.com> | |||||
Jeffrey Charles <jeffreycharles at gmail.com> | |||||
Jian Zhen <zhenjl at gmail.com> | Jian Zhen <zhenjl at gmail.com> | ||||
Joshua Prunier <joshua.prunier at gmail.com> | Joshua Prunier <joshua.prunier at gmail.com> | ||||
Julien Lefevre <julien.lefevr at gmail.com> | Julien Lefevre <julien.lefevr at gmail.com> | ||||
Julien Schmidt <go-sql-driver at julienschmidt.com> | Julien Schmidt <go-sql-driver at julienschmidt.com> | ||||
Justin Li <jli at j-li.net> | |||||
Justin Nuß <nuss.justin at gmail.com> | |||||
Kamil Dziedzic <kamil at klecza.pl> | Kamil Dziedzic <kamil at klecza.pl> | ||||
Kevin Malachowski <kevin at chowski.com> | Kevin Malachowski <kevin at chowski.com> | ||||
Kieron Woodhouse <kieron.woodhouse at infosum.com> | |||||
Lennart Rudolph <lrudolph at hmc.edu> | Lennart Rudolph <lrudolph at hmc.edu> | ||||
Leonardo YongUk Kim <dalinaum at gmail.com> | Leonardo YongUk Kim <dalinaum at gmail.com> | ||||
Linh Tran Tuan <linhduonggnu at gmail.com> | |||||
Lion Yang <lion at aosc.xyz> | |||||
Luca Looz <luca.looz92 at gmail.com> | Luca Looz <luca.looz92 at gmail.com> | ||||
Lucas Liu <extrafliu at gmail.com> | Lucas Liu <extrafliu at gmail.com> | ||||
Luke Scott <luke at webconnex.com> | Luke Scott <luke at webconnex.com> | ||||
Maciej Zimnoch <maciej.zimnoch at codilime.com> | |||||
Michael Woolnough <michael.woolnough at gmail.com> | Michael Woolnough <michael.woolnough at gmail.com> | ||||
Nicola Peduzzi <thenikso at gmail.com> | Nicola Peduzzi <thenikso at gmail.com> | ||||
Olivier Mengué <dolmen at cpan.org> | |||||
oscarzhao <oscarzhaosl at gmail.com> | |||||
Paul Bonser <misterpib at gmail.com> | Paul Bonser <misterpib at gmail.com> | ||||
Peter Schultz <peter.schultz at classmarkets.com> | |||||
Rebecca Chin <rchin at pivotal.io> | |||||
Reed Allman <rdallman10 at gmail.com> | |||||
Richard Wilkes <wilkes at me.com> | |||||
Robert Russell <robert at rrbrussell.com> | |||||
Runrioter Wung <runrioter at gmail.com> | Runrioter Wung <runrioter at gmail.com> | ||||
Shuode Li <elemount at qq.com> | |||||
Soroush Pour <me at soroushjp.com> | Soroush Pour <me at soroushjp.com> | ||||
Stan Putrya <root.vagner at gmail.com> | Stan Putrya <root.vagner at gmail.com> | ||||
Stanley Gunawan <gunawan.stanley at gmail.com> | Stanley Gunawan <gunawan.stanley at gmail.com> | ||||
@@ -51,5 +80,10 @@ Zhenye Xie <xiezhenye at gmail.com> | |||||
# Organizations | # Organizations | ||||
Barracuda Networks, Inc. | Barracuda Networks, Inc. | ||||
Counting Ltd. | |||||
Google Inc. | Google Inc. | ||||
InfoSum Ltd. | |||||
Keybase Inc. | |||||
Percona LLC | |||||
Pivotal Inc. | |||||
Stripe Inc. | Stripe Inc. |
@@ -11,7 +11,7 @@ | |||||
package mysql | package mysql | ||||
import ( | import ( | ||||
"appengine/cloudsql" | |||||
"google.golang.org/appengine/cloudsql" | |||||
) | ) | ||||
func init() { | func init() { | ||||
@@ -0,0 +1,420 @@ | |||||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||||
// | |||||
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. | |||||
// | |||||
// This Source Code Form is subject to the terms of the Mozilla Public | |||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||||
// You can obtain one at http://mozilla.org/MPL/2.0/. | |||||
package mysql | |||||
import ( | |||||
"crypto/rand" | |||||
"crypto/rsa" | |||||
"crypto/sha1" | |||||
"crypto/sha256" | |||||
"crypto/x509" | |||||
"encoding/pem" | |||||
"sync" | |||||
) | |||||
// server pub keys registry | |||||
var ( | |||||
serverPubKeyLock sync.RWMutex | |||||
serverPubKeyRegistry map[string]*rsa.PublicKey | |||||
) | |||||
// RegisterServerPubKey registers a server RSA public key which can be used to | |||||
// send data in a secure manner to the server without receiving the public key | |||||
// in a potentially insecure way from the server first. | |||||
// Registered keys can afterwards be used adding serverPubKey=<name> to the DSN. | |||||
// | |||||
// Note: The provided rsa.PublicKey instance is exclusively owned by the driver | |||||
// after registering it and may not be modified. | |||||
// | |||||
// data, err := ioutil.ReadFile("mykey.pem") | |||||
// if err != nil { | |||||
// log.Fatal(err) | |||||
// } | |||||
// | |||||
// block, _ := pem.Decode(data) | |||||
// if block == nil || block.Type != "PUBLIC KEY" { | |||||
// log.Fatal("failed to decode PEM block containing public key") | |||||
// } | |||||
// | |||||
// pub, err := x509.ParsePKIXPublicKey(block.Bytes) | |||||
// if err != nil { | |||||
// log.Fatal(err) | |||||
// } | |||||
// | |||||
// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { | |||||
// mysql.RegisterServerPubKey("mykey", rsaPubKey) | |||||
// } else { | |||||
// log.Fatal("not a RSA public key") | |||||
// } | |||||
// | |||||
func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { | |||||
serverPubKeyLock.Lock() | |||||
if serverPubKeyRegistry == nil { | |||||
serverPubKeyRegistry = make(map[string]*rsa.PublicKey) | |||||
} | |||||
serverPubKeyRegistry[name] = pubKey | |||||
serverPubKeyLock.Unlock() | |||||
} | |||||
// DeregisterServerPubKey removes the public key registered with the given name. | |||||
func DeregisterServerPubKey(name string) { | |||||
serverPubKeyLock.Lock() | |||||
if serverPubKeyRegistry != nil { | |||||
delete(serverPubKeyRegistry, name) | |||||
} | |||||
serverPubKeyLock.Unlock() | |||||
} | |||||
func getServerPubKey(name string) (pubKey *rsa.PublicKey) { | |||||
serverPubKeyLock.RLock() | |||||
if v, ok := serverPubKeyRegistry[name]; ok { | |||||
pubKey = v | |||||
} | |||||
serverPubKeyLock.RUnlock() | |||||
return | |||||
} | |||||
// Hash password using pre 4.1 (old password) method | |||||
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c | |||||
type myRnd struct { | |||||
seed1, seed2 uint32 | |||||
} | |||||
const myRndMaxVal = 0x3FFFFFFF | |||||
// Pseudo random number generator | |||||
func newMyRnd(seed1, seed2 uint32) *myRnd { | |||||
return &myRnd{ | |||||
seed1: seed1 % myRndMaxVal, | |||||
seed2: seed2 % myRndMaxVal, | |||||
} | |||||
} | |||||
// Tested to be equivalent to MariaDB's floating point variant | |||||
// http://play.golang.org/p/QHvhd4qved | |||||
// http://play.golang.org/p/RG0q4ElWDx | |||||
func (r *myRnd) NextByte() byte { | |||||
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal | |||||
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal | |||||
return byte(uint64(r.seed1) * 31 / myRndMaxVal) | |||||
} | |||||
// Generate binary hash from byte string using insecure pre 4.1 method | |||||
func pwHash(password []byte) (result [2]uint32) { | |||||
var add uint32 = 7 | |||||
var tmp uint32 | |||||
result[0] = 1345345333 | |||||
result[1] = 0x12345671 | |||||
for _, c := range password { | |||||
// skip spaces and tabs in password | |||||
if c == ' ' || c == '\t' { | |||||
continue | |||||
} | |||||
tmp = uint32(c) | |||||
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) | |||||
result[1] += (result[1] << 8) ^ result[0] | |||||
add += tmp | |||||
} | |||||
// Remove sign bit (1<<31)-1) | |||||
result[0] &= 0x7FFFFFFF | |||||
result[1] &= 0x7FFFFFFF | |||||
return | |||||
} | |||||
// Hash password using insecure pre 4.1 method | |||||
func scrambleOldPassword(scramble []byte, password string) []byte { | |||||
if len(password) == 0 { | |||||
return nil | |||||
} | |||||
scramble = scramble[:8] | |||||
hashPw := pwHash([]byte(password)) | |||||
hashSc := pwHash(scramble) | |||||
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) | |||||
var out [8]byte | |||||
for i := range out { | |||||
out[i] = r.NextByte() + 64 | |||||
} | |||||
mask := r.NextByte() | |||||
for i := range out { | |||||
out[i] ^= mask | |||||
} | |||||
return out[:] | |||||
} | |||||
// Hash password using 4.1+ method (SHA1) | |||||
func scramblePassword(scramble []byte, password string) []byte { | |||||
if len(password) == 0 { | |||||
return nil | |||||
} | |||||
// stage1Hash = SHA1(password) | |||||
crypt := sha1.New() | |||||
crypt.Write([]byte(password)) | |||||
stage1 := crypt.Sum(nil) | |||||
// scrambleHash = SHA1(scramble + SHA1(stage1Hash)) | |||||
// inner Hash | |||||
crypt.Reset() | |||||
crypt.Write(stage1) | |||||
hash := crypt.Sum(nil) | |||||
// outer Hash | |||||
crypt.Reset() | |||||
crypt.Write(scramble) | |||||
crypt.Write(hash) | |||||
scramble = crypt.Sum(nil) | |||||
// token = scrambleHash XOR stage1Hash | |||||
for i := range scramble { | |||||
scramble[i] ^= stage1[i] | |||||
} | |||||
return scramble | |||||
} | |||||
// Hash password using MySQL 8+ method (SHA256) | |||||
func scrambleSHA256Password(scramble []byte, password string) []byte { | |||||
if len(password) == 0 { | |||||
return nil | |||||
} | |||||
// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) | |||||
crypt := sha256.New() | |||||
crypt.Write([]byte(password)) | |||||
message1 := crypt.Sum(nil) | |||||
crypt.Reset() | |||||
crypt.Write(message1) | |||||
message1Hash := crypt.Sum(nil) | |||||
crypt.Reset() | |||||
crypt.Write(message1Hash) | |||||
crypt.Write(scramble) | |||||
message2 := crypt.Sum(nil) | |||||
for i := range message1 { | |||||
message1[i] ^= message2[i] | |||||
} | |||||
return message1 | |||||
} | |||||
func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { | |||||
plain := make([]byte, len(password)+1) | |||||
copy(plain, password) | |||||
for i := range plain { | |||||
j := i % len(seed) | |||||
plain[i] ^= seed[j] | |||||
} | |||||
sha1 := sha1.New() | |||||
return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) | |||||
} | |||||
func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { | |||||
enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
return mc.writeAuthSwitchPacket(enc, false) | |||||
} | |||||
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) { | |||||
switch plugin { | |||||
case "caching_sha2_password": | |||||
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) | |||||
return authResp, (authResp == nil), nil | |||||
case "mysql_old_password": | |||||
if !mc.cfg.AllowOldPasswords { | |||||
return nil, false, ErrOldPassword | |||||
} | |||||
// Note: there are edge cases where this should work but doesn't; | |||||
// this is currently "wontfix": | |||||
// https://github.com/go-sql-driver/mysql/issues/184 | |||||
authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd) | |||||
return authResp, true, nil | |||||
case "mysql_clear_password": | |||||
if !mc.cfg.AllowCleartextPasswords { | |||||
return nil, false, ErrCleartextPassword | |||||
} | |||||
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html | |||||
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html | |||||
return []byte(mc.cfg.Passwd), true, nil | |||||
case "mysql_native_password": | |||||
if !mc.cfg.AllowNativePasswords { | |||||
return nil, false, ErrNativePassword | |||||
} | |||||
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html | |||||
// Native password authentication only need and will need 20-byte challenge. | |||||
authResp := scramblePassword(authData[:20], mc.cfg.Passwd) | |||||
return authResp, false, nil | |||||
case "sha256_password": | |||||
if len(mc.cfg.Passwd) == 0 { | |||||
return nil, true, nil | |||||
} | |||||
if mc.cfg.tls != nil || mc.cfg.Net == "unix" { | |||||
// write cleartext auth packet | |||||
return []byte(mc.cfg.Passwd), true, nil | |||||
} | |||||
pubKey := mc.cfg.pubKey | |||||
if pubKey == nil { | |||||
// request public key from server | |||||
return []byte{1}, false, nil | |||||
} | |||||
// encrypted password | |||||
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) | |||||
return enc, false, err | |||||
default: | |||||
errLog.Print("unknown auth plugin:", plugin) | |||||
return nil, false, ErrUnknownPlugin | |||||
} | |||||
} | |||||
func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { | |||||
// Read Result Packet | |||||
authData, newPlugin, err := mc.readAuthResult() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
// handle auth plugin switch, if requested | |||||
if newPlugin != "" { | |||||
// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is | |||||
// sent and we have to keep using the cipher sent in the init packet. | |||||
if authData == nil { | |||||
authData = oldAuthData | |||||
} else { | |||||
// copy data from read buffer to owned slice | |||||
copy(oldAuthData, authData) | |||||
} | |||||
plugin = newPlugin | |||||
authResp, addNUL, err := mc.auth(authData, plugin) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil { | |||||
return err | |||||
} | |||||
// Read Result Packet | |||||
authData, newPlugin, err = mc.readAuthResult() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
// Do not allow to change the auth plugin more than once | |||||
if newPlugin != "" { | |||||
return ErrMalformPkt | |||||
} | |||||
} | |||||
switch plugin { | |||||
// https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ | |||||
case "caching_sha2_password": | |||||
switch len(authData) { | |||||
case 0: | |||||
return nil // auth successful | |||||
case 1: | |||||
switch authData[0] { | |||||
case cachingSha2PasswordFastAuthSuccess: | |||||
if err = mc.readResultOK(); err == nil { | |||||
return nil // auth successful | |||||
} | |||||
case cachingSha2PasswordPerformFullAuthentication: | |||||
if mc.cfg.tls != nil || mc.cfg.Net == "unix" { | |||||
// write cleartext auth packet | |||||
err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
} else { | |||||
pubKey := mc.cfg.pubKey | |||||
if pubKey == nil { | |||||
// request public key from server | |||||
data := mc.buf.takeSmallBuffer(4 + 1) | |||||
data[4] = cachingSha2PasswordRequestPublicKey | |||||
mc.writePacket(data) | |||||
// parse public key | |||||
data, err := mc.readPacket() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
block, _ := pem.Decode(data[1:]) | |||||
pkix, err := x509.ParsePKIXPublicKey(block.Bytes) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
pubKey = pkix.(*rsa.PublicKey) | |||||
} | |||||
// send encrypted password | |||||
err = mc.sendEncryptedPassword(oldAuthData, pubKey) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
} | |||||
return mc.readResultOK() | |||||
default: | |||||
return ErrMalformPkt | |||||
} | |||||
default: | |||||
return ErrMalformPkt | |||||
} | |||||
case "sha256_password": | |||||
switch len(authData) { | |||||
case 0: | |||||
return nil // auth successful | |||||
default: | |||||
block, _ := pem.Decode(authData) | |||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
// send encrypted password | |||||
err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey)) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
return mc.readResultOK() | |||||
} | |||||
default: | |||||
return nil // auth successful | |||||
} | |||||
return err | |||||
} |
@@ -130,18 +130,18 @@ func (b *buffer) takeBuffer(length int) []byte { | |||||
// smaller than defaultBufSize | // smaller than defaultBufSize | ||||
// Only one buffer (total) can be used at a time. | // Only one buffer (total) can be used at a time. | ||||
func (b *buffer) takeSmallBuffer(length int) []byte { | func (b *buffer) takeSmallBuffer(length int) []byte { | ||||
if b.length == 0 { | |||||
return b.buf[:length] | |||||
if b.length > 0 { | |||||
return nil | |||||
} | } | ||||
return nil | |||||
return b.buf[:length] | |||||
} | } | ||||
// takeCompleteBuffer returns the complete existing buffer. | // takeCompleteBuffer returns the complete existing buffer. | ||||
// This can be used if the necessary buffer size is unknown. | // This can be used if the necessary buffer size is unknown. | ||||
// Only one buffer (total) can be used at a time. | // Only one buffer (total) can be used at a time. | ||||
func (b *buffer) takeCompleteBuffer() []byte { | func (b *buffer) takeCompleteBuffer() []byte { | ||||
if b.length == 0 { | |||||
return b.buf | |||||
if b.length > 0 { | |||||
return nil | |||||
} | } | ||||
return nil | |||||
return b.buf | |||||
} | } |
@@ -9,6 +9,7 @@ | |||||
package mysql | package mysql | ||||
const defaultCollation = "utf8_general_ci" | const defaultCollation = "utf8_general_ci" | ||||
const binaryCollation = "binary" | |||||
// A list of available collations mapped to the internal ID. | // A list of available collations mapped to the internal ID. | ||||
// To update this map use the following MySQL query: | // To update this map use the following MySQL query: | ||||
@@ -10,12 +10,23 @@ package mysql | |||||
import ( | import ( | ||||
"database/sql/driver" | "database/sql/driver" | ||||
"io" | |||||
"net" | "net" | ||||
"strconv" | "strconv" | ||||
"strings" | "strings" | ||||
"time" | "time" | ||||
) | ) | ||||
// a copy of context.Context for Go 1.7 and earlier | |||||
type mysqlContext interface { | |||||
Done() <-chan struct{} | |||||
Err() error | |||||
// defined in context.Context, but not used in this driver: | |||||
// Deadline() (deadline time.Time, ok bool) | |||||
// Value(key interface{}) interface{} | |||||
} | |||||
type mysqlConn struct { | type mysqlConn struct { | ||||
buf buffer | buf buffer | ||||
netConn net.Conn | netConn net.Conn | ||||
@@ -29,7 +40,14 @@ type mysqlConn struct { | |||||
status statusFlag | status statusFlag | ||||
sequence uint8 | sequence uint8 | ||||
parseTime bool | parseTime bool | ||||
strict bool | |||||
// for context support (Go 1.8+) | |||||
watching bool | |||||
watcher chan<- mysqlContext | |||||
closech chan struct{} | |||||
finished chan<- struct{} | |||||
canceled atomicError // set non-nil if conn is canceled | |||||
closed atomicBool // set when conn is closed, before closech is closed | |||||
} | } | ||||
// Handles parameters set in DSN after the connection is established | // Handles parameters set in DSN after the connection is established | ||||
@@ -62,22 +80,41 @@ func (mc *mysqlConn) handleParams() (err error) { | |||||
return | return | ||||
} | } | ||||
func (mc *mysqlConn) markBadConn(err error) error { | |||||
if mc == nil { | |||||
return err | |||||
} | |||||
if err != errBadConnNoWrite { | |||||
return err | |||||
} | |||||
return driver.ErrBadConn | |||||
} | |||||
func (mc *mysqlConn) Begin() (driver.Tx, error) { | func (mc *mysqlConn) Begin() (driver.Tx, error) { | ||||
if mc.netConn == nil { | |||||
return mc.begin(false) | |||||
} | |||||
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { | |||||
if mc.closed.IsSet() { | |||||
errLog.Print(ErrInvalidConn) | errLog.Print(ErrInvalidConn) | ||||
return nil, driver.ErrBadConn | return nil, driver.ErrBadConn | ||||
} | } | ||||
err := mc.exec("START TRANSACTION") | |||||
var q string | |||||
if readOnly { | |||||
q = "START TRANSACTION READ ONLY" | |||||
} else { | |||||
q = "START TRANSACTION" | |||||
} | |||||
err := mc.exec(q) | |||||
if err == nil { | if err == nil { | ||||
return &mysqlTx{mc}, err | return &mysqlTx{mc}, err | ||||
} | } | ||||
return nil, err | |||||
return nil, mc.markBadConn(err) | |||||
} | } | ||||
func (mc *mysqlConn) Close() (err error) { | func (mc *mysqlConn) Close() (err error) { | ||||
// Makes Close idempotent | // Makes Close idempotent | ||||
if mc.netConn != nil { | |||||
if !mc.closed.IsSet() { | |||||
err = mc.writeCommandPacket(comQuit) | err = mc.writeCommandPacket(comQuit) | ||||
} | } | ||||
@@ -91,26 +128,39 @@ func (mc *mysqlConn) Close() (err error) { | |||||
// is called before auth or on auth failure because MySQL will have already | // is called before auth or on auth failure because MySQL will have already | ||||
// closed the network connection. | // closed the network connection. | ||||
func (mc *mysqlConn) cleanup() { | func (mc *mysqlConn) cleanup() { | ||||
if !mc.closed.TrySet(true) { | |||||
return | |||||
} | |||||
// Makes cleanup idempotent | // Makes cleanup idempotent | ||||
if mc.netConn != nil { | |||||
if err := mc.netConn.Close(); err != nil { | |||||
errLog.Print(err) | |||||
close(mc.closech) | |||||
if mc.netConn == nil { | |||||
return | |||||
} | |||||
if err := mc.netConn.Close(); err != nil { | |||||
errLog.Print(err) | |||||
} | |||||
} | |||||
func (mc *mysqlConn) error() error { | |||||
if mc.closed.IsSet() { | |||||
if err := mc.canceled.Value(); err != nil { | |||||
return err | |||||
} | } | ||||
mc.netConn = nil | |||||
return ErrInvalidConn | |||||
} | } | ||||
mc.cfg = nil | |||||
mc.buf.nc = nil | |||||
return nil | |||||
} | } | ||||
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { | func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { | ||||
if mc.netConn == nil { | |||||
if mc.closed.IsSet() { | |||||
errLog.Print(ErrInvalidConn) | errLog.Print(ErrInvalidConn) | ||||
return nil, driver.ErrBadConn | return nil, driver.ErrBadConn | ||||
} | } | ||||
// Send command | // Send command | ||||
err := mc.writeCommandPacketStr(comStmtPrepare, query) | err := mc.writeCommandPacketStr(comStmtPrepare, query) | ||||
if err != nil { | if err != nil { | ||||
return nil, err | |||||
return nil, mc.markBadConn(err) | |||||
} | } | ||||
stmt := &mysqlStmt{ | stmt := &mysqlStmt{ | ||||
@@ -144,7 +194,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin | |||||
if buf == nil { | if buf == nil { | ||||
// can not take the buffer. Something must be wrong with the connection | // can not take the buffer. Something must be wrong with the connection | ||||
errLog.Print(ErrBusyBuffer) | errLog.Print(ErrBusyBuffer) | ||||
return "", driver.ErrBadConn | |||||
return "", ErrInvalidConn | |||||
} | } | ||||
buf = buf[:0] | buf = buf[:0] | ||||
argPos := 0 | argPos := 0 | ||||
@@ -257,7 +307,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin | |||||
} | } | ||||
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { | func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { | ||||
if mc.netConn == nil { | |||||
if mc.closed.IsSet() { | |||||
errLog.Print(ErrInvalidConn) | errLog.Print(ErrInvalidConn) | ||||
return nil, driver.ErrBadConn | return nil, driver.ErrBadConn | ||||
} | } | ||||
@@ -271,7 +321,6 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err | |||||
return nil, err | return nil, err | ||||
} | } | ||||
query = prepared | query = prepared | ||||
args = nil | |||||
} | } | ||||
mc.affectedRows = 0 | mc.affectedRows = 0 | ||||
mc.insertId = 0 | mc.insertId = 0 | ||||
@@ -283,32 +332,43 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err | |||||
insertId: int64(mc.insertId), | insertId: int64(mc.insertId), | ||||
}, err | }, err | ||||
} | } | ||||
return nil, err | |||||
return nil, mc.markBadConn(err) | |||||
} | } | ||||
// Internal function to execute commands | // Internal function to execute commands | ||||
func (mc *mysqlConn) exec(query string) error { | func (mc *mysqlConn) exec(query string) error { | ||||
// Send command | // Send command | ||||
err := mc.writeCommandPacketStr(comQuery, query) | |||||
if err != nil { | |||||
return err | |||||
if err := mc.writeCommandPacketStr(comQuery, query); err != nil { | |||||
return mc.markBadConn(err) | |||||
} | } | ||||
// Read Result | // Read Result | ||||
resLen, err := mc.readResultSetHeaderPacket() | resLen, err := mc.readResultSetHeaderPacket() | ||||
if err == nil && resLen > 0 { | |||||
if err = mc.readUntilEOF(); err != nil { | |||||
if err != nil { | |||||
return err | |||||
} | |||||
if resLen > 0 { | |||||
// columns | |||||
if err := mc.readUntilEOF(); err != nil { | |||||
return err | return err | ||||
} | } | ||||
err = mc.readUntilEOF() | |||||
// rows | |||||
if err := mc.readUntilEOF(); err != nil { | |||||
return err | |||||
} | |||||
} | } | ||||
return err | |||||
return mc.discardResults() | |||||
} | } | ||||
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { | func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { | ||||
if mc.netConn == nil { | |||||
return mc.query(query, args) | |||||
} | |||||
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { | |||||
if mc.closed.IsSet() { | |||||
errLog.Print(ErrInvalidConn) | errLog.Print(ErrInvalidConn) | ||||
return nil, driver.ErrBadConn | return nil, driver.ErrBadConn | ||||
} | } | ||||
@@ -322,7 +382,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro | |||||
return nil, err | return nil, err | ||||
} | } | ||||
query = prepared | query = prepared | ||||
args = nil | |||||
} | } | ||||
// Send command | // Send command | ||||
err := mc.writeCommandPacketStr(comQuery, query) | err := mc.writeCommandPacketStr(comQuery, query) | ||||
@@ -335,15 +394,22 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro | |||||
rows.mc = mc | rows.mc = mc | ||||
if resLen == 0 { | if resLen == 0 { | ||||
// no columns, no more data | |||||
return emptyRows{}, nil | |||||
rows.rs.done = true | |||||
switch err := rows.NextResultSet(); err { | |||||
case nil, io.EOF: | |||||
return rows, nil | |||||
default: | |||||
return nil, err | |||||
} | |||||
} | } | ||||
// Columns | // Columns | ||||
rows.columns, err = mc.readColumns(resLen) | |||||
rows.rs.columns, err = mc.readColumns(resLen) | |||||
return rows, err | return rows, err | ||||
} | } | ||||
} | } | ||||
return nil, err | |||||
return nil, mc.markBadConn(err) | |||||
} | } | ||||
// Gets the value of the given MySQL System Variable | // Gets the value of the given MySQL System Variable | ||||
@@ -359,7 +425,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { | |||||
if err == nil { | if err == nil { | ||||
rows := new(textRows) | rows := new(textRows) | ||||
rows.mc = mc | rows.mc = mc | ||||
rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}} | |||||
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} | |||||
if resLen > 0 { | if resLen > 0 { | ||||
// Columns | // Columns | ||||
@@ -375,3 +441,21 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { | |||||
} | } | ||||
return nil, err | return nil, err | ||||
} | } | ||||
// finish is called when the query has canceled. | |||||
func (mc *mysqlConn) cancel(err error) { | |||||
mc.canceled.Set(err) | |||||
mc.cleanup() | |||||
} | |||||
// finish is called when the query has succeeded. | |||||
func (mc *mysqlConn) finish() { | |||||
if !mc.watching || mc.finished == nil { | |||||
return | |||||
} | |||||
select { | |||||
case mc.finished <- struct{}{}: | |||||
mc.watching = false | |||||
case <-mc.closech: | |||||
} | |||||
} |
@@ -0,0 +1,208 @@ | |||||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||||
// | |||||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. | |||||
// | |||||
// This Source Code Form is subject to the terms of the Mozilla Public | |||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||||
// You can obtain one at http://mozilla.org/MPL/2.0/. | |||||
// +build go1.8 | |||||
package mysql | |||||
import ( | |||||
"context" | |||||
"database/sql" | |||||
"database/sql/driver" | |||||
) | |||||
// Ping implements driver.Pinger interface | |||||
func (mc *mysqlConn) Ping(ctx context.Context) (err error) { | |||||
if mc.closed.IsSet() { | |||||
errLog.Print(ErrInvalidConn) | |||||
return driver.ErrBadConn | |||||
} | |||||
if err = mc.watchCancel(ctx); err != nil { | |||||
return | |||||
} | |||||
defer mc.finish() | |||||
if err = mc.writeCommandPacket(comPing); err != nil { | |||||
return | |||||
} | |||||
return mc.readResultOK() | |||||
} | |||||
// BeginTx implements driver.ConnBeginTx interface | |||||
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | |||||
if err := mc.watchCancel(ctx); err != nil { | |||||
return nil, err | |||||
} | |||||
defer mc.finish() | |||||
if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { | |||||
level, err := mapIsolationLevel(opts.Isolation) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
} | |||||
return mc.begin(opts.ReadOnly) | |||||
} | |||||
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||||
dargs, err := namedValueToValue(args) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if err := mc.watchCancel(ctx); err != nil { | |||||
return nil, err | |||||
} | |||||
rows, err := mc.query(query, dargs) | |||||
if err != nil { | |||||
mc.finish() | |||||
return nil, err | |||||
} | |||||
rows.finish = mc.finish | |||||
return rows, err | |||||
} | |||||
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | |||||
dargs, err := namedValueToValue(args) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if err := mc.watchCancel(ctx); err != nil { | |||||
return nil, err | |||||
} | |||||
defer mc.finish() | |||||
return mc.Exec(query, dargs) | |||||
} | |||||
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | |||||
if err := mc.watchCancel(ctx); err != nil { | |||||
return nil, err | |||||
} | |||||
stmt, err := mc.Prepare(query) | |||||
mc.finish() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
select { | |||||
default: | |||||
case <-ctx.Done(): | |||||
stmt.Close() | |||||
return nil, ctx.Err() | |||||
} | |||||
return stmt, nil | |||||
} | |||||
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { | |||||
dargs, err := namedValueToValue(args) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if err := stmt.mc.watchCancel(ctx); err != nil { | |||||
return nil, err | |||||
} | |||||
rows, err := stmt.query(dargs) | |||||
if err != nil { | |||||
stmt.mc.finish() | |||||
return nil, err | |||||
} | |||||
rows.finish = stmt.mc.finish | |||||
return rows, err | |||||
} | |||||
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { | |||||
dargs, err := namedValueToValue(args) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if err := stmt.mc.watchCancel(ctx); err != nil { | |||||
return nil, err | |||||
} | |||||
defer stmt.mc.finish() | |||||
return stmt.Exec(dargs) | |||||
} | |||||
func (mc *mysqlConn) watchCancel(ctx context.Context) error { | |||||
if mc.watching { | |||||
// Reach here if canceled, | |||||
// so the connection is already invalid | |||||
mc.cleanup() | |||||
return nil | |||||
} | |||||
if ctx.Done() == nil { | |||||
return nil | |||||
} | |||||
mc.watching = true | |||||
select { | |||||
default: | |||||
case <-ctx.Done(): | |||||
return ctx.Err() | |||||
} | |||||
if mc.watcher == nil { | |||||
return nil | |||||
} | |||||
mc.watcher <- ctx | |||||
return nil | |||||
} | |||||
func (mc *mysqlConn) startWatcher() { | |||||
watcher := make(chan mysqlContext, 1) | |||||
mc.watcher = watcher | |||||
finished := make(chan struct{}) | |||||
mc.finished = finished | |||||
go func() { | |||||
for { | |||||
var ctx mysqlContext | |||||
select { | |||||
case ctx = <-watcher: | |||||
case <-mc.closech: | |||||
return | |||||
} | |||||
select { | |||||
case <-ctx.Done(): | |||||
mc.cancel(ctx.Err()) | |||||
case <-finished: | |||||
case <-mc.closech: | |||||
return | |||||
} | |||||
} | |||||
}() | |||||
} | |||||
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { | |||||
nv.Value, err = converter{}.ConvertValue(nv.Value) | |||||
return | |||||
} | |||||
// ResetSession implements driver.SessionResetter. | |||||
// (From Go 1.10) | |||||
func (mc *mysqlConn) ResetSession(ctx context.Context) error { | |||||
if mc.closed.IsSet() { | |||||
return driver.ErrBadConn | |||||
} | |||||
return nil | |||||
} |
@@ -9,7 +9,9 @@ | |||||
package mysql | package mysql | ||||
const ( | const ( | ||||
minProtocolVersion byte = 10 | |||||
defaultAuthPlugin = "mysql_native_password" | |||||
defaultMaxAllowedPacket = 4 << 20 // 4 MiB | |||||
minProtocolVersion = 10 | |||||
maxPacketSize = 1<<24 - 1 | maxPacketSize = 1<<24 - 1 | ||||
timeFormat = "2006-01-02 15:04:05.999999" | timeFormat = "2006-01-02 15:04:05.999999" | ||||
) | ) | ||||
@@ -18,10 +20,11 @@ const ( | |||||
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html | // http://dev.mysql.com/doc/internals/en/client-server-protocol.html | ||||
const ( | const ( | ||||
iOK byte = 0x00 | |||||
iLocalInFile byte = 0xfb | |||||
iEOF byte = 0xfe | |||||
iERR byte = 0xff | |||||
iOK byte = 0x00 | |||||
iAuthMoreData byte = 0x01 | |||||
iLocalInFile byte = 0xfb | |||||
iEOF byte = 0xfe | |||||
iERR byte = 0xff | |||||
) | ) | ||||
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags | // https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags | ||||
@@ -87,8 +90,10 @@ const ( | |||||
) | ) | ||||
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType | // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType | ||||
type fieldType byte | |||||
const ( | const ( | ||||
fieldTypeDecimal byte = iota | |||||
fieldTypeDecimal fieldType = iota | |||||
fieldTypeTiny | fieldTypeTiny | ||||
fieldTypeShort | fieldTypeShort | ||||
fieldTypeLong | fieldTypeLong | ||||
@@ -107,7 +112,7 @@ const ( | |||||
fieldTypeBit | fieldTypeBit | ||||
) | ) | ||||
const ( | const ( | ||||
fieldTypeJSON byte = iota + 0xf5 | |||||
fieldTypeJSON fieldType = iota + 0xf5 | |||||
fieldTypeNewDecimal | fieldTypeNewDecimal | ||||
fieldTypeEnum | fieldTypeEnum | ||||
fieldTypeSet | fieldTypeSet | ||||
@@ -161,3 +166,9 @@ const ( | |||||
statusInTransReadonly | statusInTransReadonly | ||||
statusSessionStateChanged | statusSessionStateChanged | ||||
) | ) | ||||
const ( | |||||
cachingSha2PasswordRequestPublicKey = 2 | |||||
cachingSha2PasswordFastAuthSuccess = 3 | |||||
cachingSha2PasswordPerformFullAuthentication = 4 | |||||
) |
@@ -4,7 +4,7 @@ | |||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file, | // License, v. 2.0. If a copy of the MPL was not distributed with this file, | ||||
// You can obtain one at http://mozilla.org/MPL/2.0/. | // You can obtain one at http://mozilla.org/MPL/2.0/. | ||||
// Package mysql provides a MySQL driver for Go's database/sql package | |||||
// Package mysql provides a MySQL driver for Go's database/sql package. | |||||
// | // | ||||
// The driver should be used via the database/sql package: | // The driver should be used via the database/sql package: | ||||
// | // | ||||
@@ -20,8 +20,14 @@ import ( | |||||
"database/sql" | "database/sql" | ||||
"database/sql/driver" | "database/sql/driver" | ||||
"net" | "net" | ||||
"sync" | |||||
) | ) | ||||
// watcher interface is used for context support (From Go 1.8) | |||||
type watcher interface { | |||||
startWatcher() | |||||
} | |||||
// MySQLDriver is exported to make the driver directly accessible. | // MySQLDriver is exported to make the driver directly accessible. | ||||
// In general the driver is used via the database/sql package. | // In general the driver is used via the database/sql package. | ||||
type MySQLDriver struct{} | type MySQLDriver struct{} | ||||
@@ -30,12 +36,17 @@ type MySQLDriver struct{} | |||||
// Custom dial functions must be registered with RegisterDial | // Custom dial functions must be registered with RegisterDial | ||||
type DialFunc func(addr string) (net.Conn, error) | type DialFunc func(addr string) (net.Conn, error) | ||||
var dials map[string]DialFunc | |||||
var ( | |||||
dialsLock sync.RWMutex | |||||
dials map[string]DialFunc | |||||
) | |||||
// RegisterDial registers a custom dial function. It can then be used by the | // RegisterDial registers a custom dial function. It can then be used by the | ||||
// network address mynet(addr), where mynet is the registered new network. | // network address mynet(addr), where mynet is the registered new network. | ||||
// addr is passed as a parameter to the dial function. | // addr is passed as a parameter to the dial function. | ||||
func RegisterDial(net string, dial DialFunc) { | func RegisterDial(net string, dial DialFunc) { | ||||
dialsLock.Lock() | |||||
defer dialsLock.Unlock() | |||||
if dials == nil { | if dials == nil { | ||||
dials = make(map[string]DialFunc) | dials = make(map[string]DialFunc) | ||||
} | } | ||||
@@ -52,16 +63,19 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |||||
mc := &mysqlConn{ | mc := &mysqlConn{ | ||||
maxAllowedPacket: maxPacketSize, | maxAllowedPacket: maxPacketSize, | ||||
maxWriteSize: maxPacketSize - 1, | maxWriteSize: maxPacketSize - 1, | ||||
closech: make(chan struct{}), | |||||
} | } | ||||
mc.cfg, err = ParseDSN(dsn) | mc.cfg, err = ParseDSN(dsn) | ||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
mc.parseTime = mc.cfg.ParseTime | mc.parseTime = mc.cfg.ParseTime | ||||
mc.strict = mc.cfg.Strict | |||||
// Connect to Server | // Connect to Server | ||||
if dial, ok := dials[mc.cfg.Net]; ok { | |||||
dialsLock.RLock() | |||||
dial, ok := dials[mc.cfg.Net] | |||||
dialsLock.RUnlock() | |||||
if ok { | |||||
mc.netConn, err = dial(mc.cfg.Addr) | mc.netConn, err = dial(mc.cfg.Addr) | ||||
} else { | } else { | ||||
nd := net.Dialer{Timeout: mc.cfg.Timeout} | nd := net.Dialer{Timeout: mc.cfg.Timeout} | ||||
@@ -81,6 +95,11 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |||||
} | } | ||||
} | } | ||||
// Call startWatcher for context support (From Go 1.8) | |||||
if s, ok := interface{}(mc).(watcher); ok { | |||||
s.startWatcher() | |||||
} | |||||
mc.buf = newBuffer(mc.netConn) | mc.buf = newBuffer(mc.netConn) | ||||
// Set I/O timeouts | // Set I/O timeouts | ||||
@@ -88,20 +107,31 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |||||
mc.writeTimeout = mc.cfg.WriteTimeout | mc.writeTimeout = mc.cfg.WriteTimeout | ||||
// Reading Handshake Initialization Packet | // Reading Handshake Initialization Packet | ||||
cipher, err := mc.readInitPacket() | |||||
authData, plugin, err := mc.readHandshakePacket() | |||||
if err != nil { | if err != nil { | ||||
mc.cleanup() | mc.cleanup() | ||||
return nil, err | return nil, err | ||||
} | } | ||||
// Send Client Authentication Packet | // Send Client Authentication Packet | ||||
if err = mc.writeAuthPacket(cipher); err != nil { | |||||
authResp, addNUL, err := mc.auth(authData, plugin) | |||||
if err != nil { | |||||
// try the default auth plugin, if using the requested plugin failed | |||||
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) | |||||
plugin = defaultAuthPlugin | |||||
authResp, addNUL, err = mc.auth(authData, plugin) | |||||
if err != nil { | |||||
mc.cleanup() | |||||
return nil, err | |||||
} | |||||
} | |||||
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { | |||||
mc.cleanup() | mc.cleanup() | ||||
return nil, err | return nil, err | ||||
} | } | ||||
// Handle response to auth packet, switch methods if possible | // Handle response to auth packet, switch methods if possible | ||||
if err = handleAuthResult(mc); err != nil { | |||||
if err = mc.handleAuthResult(authData, plugin); err != nil { | |||||
// Authentication failed and MySQL has already closed the connection | // Authentication failed and MySQL has already closed the connection | ||||
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html). | // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). | ||||
// Do not send COM_QUIT, just cleanup and return the error. | // Do not send COM_QUIT, just cleanup and return the error. | ||||
@@ -134,43 +164,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |||||
return mc, nil | return mc, nil | ||||
} | } | ||||
func handleAuthResult(mc *mysqlConn) error { | |||||
// Read Result Packet | |||||
cipher, err := mc.readResultOK() | |||||
if err == nil { | |||||
return nil // auth successful | |||||
} | |||||
if mc.cfg == nil { | |||||
return err // auth failed and retry not possible | |||||
} | |||||
// Retry auth if configured to do so. | |||||
if mc.cfg.AllowOldPasswords && err == ErrOldPassword { | |||||
// Retry with old authentication method. Note: there are edge cases | |||||
// where this should work but doesn't; this is currently "wontfix": | |||||
// https://github.com/go-sql-driver/mysql/issues/184 | |||||
if err = mc.writeOldAuthPacket(cipher); err != nil { | |||||
return err | |||||
} | |||||
_, err = mc.readResultOK() | |||||
} else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword { | |||||
// Retry with clear text password for | |||||
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html | |||||
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html | |||||
if err = mc.writeClearAuthPacket(); err != nil { | |||||
return err | |||||
} | |||||
_, err = mc.readResultOK() | |||||
} else if mc.cfg.AllowNativePasswords && err == ErrNativePassword { | |||||
if err = mc.writeNativeAuthPacket(cipher); err != nil { | |||||
return err | |||||
} | |||||
_, err = mc.readResultOK() | |||||
} | |||||
return err | |||||
} | |||||
func init() { | func init() { | ||||
sql.Register("mysql", &MySQLDriver{}) | sql.Register("mysql", &MySQLDriver{}) | ||||
} | } |
@@ -10,11 +10,13 @@ package mysql | |||||
import ( | import ( | ||||
"bytes" | "bytes" | ||||
"crypto/rsa" | |||||
"crypto/tls" | "crypto/tls" | ||||
"errors" | "errors" | ||||
"fmt" | "fmt" | ||||
"net" | "net" | ||||
"net/url" | "net/url" | ||||
"sort" | |||||
"strconv" | "strconv" | ||||
"strings" | "strings" | ||||
"time" | "time" | ||||
@@ -27,7 +29,9 @@ var ( | |||||
errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") | errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") | ||||
) | ) | ||||
// Config is a configuration parsed from a DSN string | |||||
// Config is a configuration parsed from a DSN string. | |||||
// If a new Config is created instead of being parsed from a DSN string, | |||||
// the NewConfig function should be used, which sets default values. | |||||
type Config struct { | type Config struct { | ||||
User string // Username | User string // Username | ||||
Passwd string // Password (requires User) | Passwd string // Password (requires User) | ||||
@@ -38,6 +42,8 @@ type Config struct { | |||||
Collation string // Connection collation | Collation string // Connection collation | ||||
Loc *time.Location // Location for time.Time values | Loc *time.Location // Location for time.Time values | ||||
MaxAllowedPacket int // Max packet size allowed | MaxAllowedPacket int // Max packet size allowed | ||||
ServerPubKey string // Server public key name | |||||
pubKey *rsa.PublicKey // Server public key | |||||
TLSConfig string // TLS configuration name | TLSConfig string // TLS configuration name | ||||
tls *tls.Config // TLS configuration | tls *tls.Config // TLS configuration | ||||
Timeout time.Duration // Dial timeout | Timeout time.Duration // Dial timeout | ||||
@@ -53,7 +59,54 @@ type Config struct { | |||||
InterpolateParams bool // Interpolate placeholders into query string | InterpolateParams bool // Interpolate placeholders into query string | ||||
MultiStatements bool // Allow multiple statements in one query | MultiStatements bool // Allow multiple statements in one query | ||||
ParseTime bool // Parse time values to time.Time | ParseTime bool // Parse time values to time.Time | ||||
Strict bool // Return warnings as errors | |||||
RejectReadOnly bool // Reject read-only connections | |||||
} | |||||
// NewConfig creates a new Config and sets default values. | |||||
func NewConfig() *Config { | |||||
return &Config{ | |||||
Collation: defaultCollation, | |||||
Loc: time.UTC, | |||||
MaxAllowedPacket: defaultMaxAllowedPacket, | |||||
AllowNativePasswords: true, | |||||
} | |||||
} | |||||
func (cfg *Config) normalize() error { | |||||
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { | |||||
return errInvalidDSNUnsafeCollation | |||||
} | |||||
// Set default network if empty | |||||
if cfg.Net == "" { | |||||
cfg.Net = "tcp" | |||||
} | |||||
// Set default address if empty | |||||
if cfg.Addr == "" { | |||||
switch cfg.Net { | |||||
case "tcp": | |||||
cfg.Addr = "127.0.0.1:3306" | |||||
case "unix": | |||||
cfg.Addr = "/tmp/mysql.sock" | |||||
default: | |||||
return errors.New("default addr for network '" + cfg.Net + "' unknown") | |||||
} | |||||
} else if cfg.Net == "tcp" { | |||||
cfg.Addr = ensureHavePort(cfg.Addr) | |||||
} | |||||
if cfg.tls != nil { | |||||
if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { | |||||
host, _, err := net.SplitHostPort(cfg.Addr) | |||||
if err == nil { | |||||
cfg.tls.ServerName = host | |||||
} | |||||
} | |||||
} | |||||
return nil | |||||
} | } | ||||
// FormatDSN formats the given Config into a DSN string which can be passed to | // FormatDSN formats the given Config into a DSN string which can be passed to | ||||
@@ -102,12 +155,12 @@ func (cfg *Config) FormatDSN() string { | |||||
} | } | ||||
} | } | ||||
if cfg.AllowNativePasswords { | |||||
if !cfg.AllowNativePasswords { | |||||
if hasParam { | if hasParam { | ||||
buf.WriteString("&allowNativePasswords=true") | |||||
buf.WriteString("&allowNativePasswords=false") | |||||
} else { | } else { | ||||
hasParam = true | hasParam = true | ||||
buf.WriteString("?allowNativePasswords=true") | |||||
buf.WriteString("?allowNativePasswords=false") | |||||
} | } | ||||
} | } | ||||
@@ -195,15 +248,25 @@ func (cfg *Config) FormatDSN() string { | |||||
buf.WriteString(cfg.ReadTimeout.String()) | buf.WriteString(cfg.ReadTimeout.String()) | ||||
} | } | ||||
if cfg.Strict { | |||||
if cfg.RejectReadOnly { | |||||
if hasParam { | if hasParam { | ||||
buf.WriteString("&strict=true") | |||||
buf.WriteString("&rejectReadOnly=true") | |||||
} else { | } else { | ||||
hasParam = true | hasParam = true | ||||
buf.WriteString("?strict=true") | |||||
buf.WriteString("?rejectReadOnly=true") | |||||
} | } | ||||
} | } | ||||
if len(cfg.ServerPubKey) > 0 { | |||||
if hasParam { | |||||
buf.WriteString("&serverPubKey=") | |||||
} else { | |||||
hasParam = true | |||||
buf.WriteString("?serverPubKey=") | |||||
} | |||||
buf.WriteString(url.QueryEscape(cfg.ServerPubKey)) | |||||
} | |||||
if cfg.Timeout > 0 { | if cfg.Timeout > 0 { | ||||
if hasParam { | if hasParam { | ||||
buf.WriteString("&timeout=") | buf.WriteString("&timeout=") | ||||
@@ -234,7 +297,7 @@ func (cfg *Config) FormatDSN() string { | |||||
buf.WriteString(cfg.WriteTimeout.String()) | buf.WriteString(cfg.WriteTimeout.String()) | ||||
} | } | ||||
if cfg.MaxAllowedPacket > 0 { | |||||
if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { | |||||
if hasParam { | if hasParam { | ||||
buf.WriteString("&maxAllowedPacket=") | buf.WriteString("&maxAllowedPacket=") | ||||
} else { | } else { | ||||
@@ -247,7 +310,12 @@ func (cfg *Config) FormatDSN() string { | |||||
// other params | // other params | ||||
if cfg.Params != nil { | if cfg.Params != nil { | ||||
for param, value := range cfg.Params { | |||||
var params []string | |||||
for param := range cfg.Params { | |||||
params = append(params, param) | |||||
} | |||||
sort.Strings(params) | |||||
for _, param := range params { | |||||
if hasParam { | if hasParam { | ||||
buf.WriteByte('&') | buf.WriteByte('&') | ||||
} else { | } else { | ||||
@@ -257,7 +325,7 @@ func (cfg *Config) FormatDSN() string { | |||||
buf.WriteString(param) | buf.WriteString(param) | ||||
buf.WriteByte('=') | buf.WriteByte('=') | ||||
buf.WriteString(url.QueryEscape(value)) | |||||
buf.WriteString(url.QueryEscape(cfg.Params[param])) | |||||
} | } | ||||
} | } | ||||
@@ -267,10 +335,7 @@ func (cfg *Config) FormatDSN() string { | |||||
// ParseDSN parses the DSN string to a Config | // ParseDSN parses the DSN string to a Config | ||||
func ParseDSN(dsn string) (cfg *Config, err error) { | func ParseDSN(dsn string) (cfg *Config, err error) { | ||||
// New config with some default values | // New config with some default values | ||||
cfg = &Config{ | |||||
Loc: time.UTC, | |||||
Collation: defaultCollation, | |||||
} | |||||
cfg = NewConfig() | |||||
// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] | // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] | ||||
// Find the last '/' (since the password or the net addr might contain a '/') | // Find the last '/' (since the password or the net addr might contain a '/') | ||||
@@ -338,28 +403,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) { | |||||
return nil, errInvalidDSNNoSlash | return nil, errInvalidDSNNoSlash | ||||
} | } | ||||
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { | |||||
return nil, errInvalidDSNUnsafeCollation | |||||
} | |||||
// Set default network if empty | |||||
if cfg.Net == "" { | |||||
cfg.Net = "tcp" | |||||
if err = cfg.normalize(); err != nil { | |||||
return nil, err | |||||
} | } | ||||
// Set default address if empty | |||||
if cfg.Addr == "" { | |||||
switch cfg.Net { | |||||
case "tcp": | |||||
cfg.Addr = "127.0.0.1:3306" | |||||
case "unix": | |||||
cfg.Addr = "/tmp/mysql.sock" | |||||
default: | |||||
return nil, errors.New("default addr for network '" + cfg.Net + "' unknown") | |||||
} | |||||
} | |||||
return | return | ||||
} | } | ||||
@@ -374,7 +420,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { | |||||
// cfg params | // cfg params | ||||
switch value := param[1]; param[0] { | switch value := param[1]; param[0] { | ||||
// Disable INFILE whitelist / enable all files | // Disable INFILE whitelist / enable all files | ||||
case "allowAllFiles": | case "allowAllFiles": | ||||
var isBool bool | var isBool bool | ||||
@@ -472,14 +517,32 @@ func parseDSNParams(cfg *Config, params string) (err error) { | |||||
return | return | ||||
} | } | ||||
// Strict mode | |||||
case "strict": | |||||
// Reject read-only connections | |||||
case "rejectReadOnly": | |||||
var isBool bool | var isBool bool | ||||
cfg.Strict, isBool = readBool(value) | |||||
cfg.RejectReadOnly, isBool = readBool(value) | |||||
if !isBool { | if !isBool { | ||||
return errors.New("invalid bool value: " + value) | return errors.New("invalid bool value: " + value) | ||||
} | } | ||||
// Server public key | |||||
case "serverPubKey": | |||||
name, err := url.QueryUnescape(value) | |||||
if err != nil { | |||||
return fmt.Errorf("invalid value for server pub key name: %v", err) | |||||
} | |||||
if pubKey := getServerPubKey(name); pubKey != nil { | |||||
cfg.ServerPubKey = name | |||||
cfg.pubKey = pubKey | |||||
} else { | |||||
return errors.New("invalid value / unknown server pub key name: " + name) | |||||
} | |||||
// Strict mode | |||||
case "strict": | |||||
panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") | |||||
// Dial Timeout | // Dial Timeout | ||||
case "timeout": | case "timeout": | ||||
cfg.Timeout, err = time.ParseDuration(value) | cfg.Timeout, err = time.ParseDuration(value) | ||||
@@ -506,14 +569,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { | |||||
return fmt.Errorf("invalid value for TLS config name: %v", err) | return fmt.Errorf("invalid value for TLS config name: %v", err) | ||||
} | } | ||||
if tlsConfig, ok := tlsConfigRegister[name]; ok { | |||||
if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { | |||||
host, _, err := net.SplitHostPort(cfg.Addr) | |||||
if err == nil { | |||||
tlsConfig.ServerName = host | |||||
} | |||||
} | |||||
if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { | |||||
cfg.TLSConfig = name | cfg.TLSConfig = name | ||||
cfg.tls = tlsConfig | cfg.tls = tlsConfig | ||||
} else { | } else { | ||||
@@ -546,3 +602,10 @@ func parseDSNParams(cfg *Config, params string) (err error) { | |||||
return | return | ||||
} | } | ||||
func ensureHavePort(addr string) string { | |||||
if _, _, err := net.SplitHostPort(addr); err != nil { | |||||
return net.JoinHostPort(addr, "3306") | |||||
} | |||||
return addr | |||||
} |
@@ -9,10 +9,8 @@ | |||||
package mysql | package mysql | ||||
import ( | import ( | ||||
"database/sql/driver" | |||||
"errors" | "errors" | ||||
"fmt" | "fmt" | ||||
"io" | |||||
"log" | "log" | ||||
"os" | "os" | ||||
) | ) | ||||
@@ -31,6 +29,12 @@ var ( | |||||
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") | ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") | ||||
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") | ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") | ||||
ErrBusyBuffer = errors.New("busy buffer") | ErrBusyBuffer = errors.New("busy buffer") | ||||
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. | |||||
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn | |||||
// to trigger a resend. | |||||
// See https://github.com/go-sql-driver/mysql/pull/302 | |||||
errBadConnNoWrite = errors.New("bad connection") | |||||
) | ) | ||||
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) | var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) | ||||
@@ -59,74 +63,3 @@ type MySQLError struct { | |||||
func (me *MySQLError) Error() string { | func (me *MySQLError) Error() string { | ||||
return fmt.Sprintf("Error %d: %s", me.Number, me.Message) | return fmt.Sprintf("Error %d: %s", me.Number, me.Message) | ||||
} | } | ||||
// MySQLWarnings is an error type which represents a group of one or more MySQL | |||||
// warnings | |||||
type MySQLWarnings []MySQLWarning | |||||
func (mws MySQLWarnings) Error() string { | |||||
var msg string | |||||
for i, warning := range mws { | |||||
if i > 0 { | |||||
msg += "\r\n" | |||||
} | |||||
msg += fmt.Sprintf( | |||||
"%s %s: %s", | |||||
warning.Level, | |||||
warning.Code, | |||||
warning.Message, | |||||
) | |||||
} | |||||
return msg | |||||
} | |||||
// MySQLWarning is an error type which represents a single MySQL warning. | |||||
// Warnings are returned in groups only. See MySQLWarnings | |||||
type MySQLWarning struct { | |||||
Level string | |||||
Code string | |||||
Message string | |||||
} | |||||
func (mc *mysqlConn) getWarnings() (err error) { | |||||
rows, err := mc.Query("SHOW WARNINGS", nil) | |||||
if err != nil { | |||||
return | |||||
} | |||||
var warnings = MySQLWarnings{} | |||||
var values = make([]driver.Value, 3) | |||||
for { | |||||
err = rows.Next(values) | |||||
switch err { | |||||
case nil: | |||||
warning := MySQLWarning{} | |||||
if raw, ok := values[0].([]byte); ok { | |||||
warning.Level = string(raw) | |||||
} else { | |||||
warning.Level = fmt.Sprintf("%s", values[0]) | |||||
} | |||||
if raw, ok := values[1].([]byte); ok { | |||||
warning.Code = string(raw) | |||||
} else { | |||||
warning.Code = fmt.Sprintf("%s", values[1]) | |||||
} | |||||
if raw, ok := values[2].([]byte); ok { | |||||
warning.Message = string(raw) | |||||
} else { | |||||
warning.Message = fmt.Sprintf("%s", values[0]) | |||||
} | |||||
warnings = append(warnings, warning) | |||||
case io.EOF: | |||||
return warnings | |||||
default: | |||||
rows.Close() | |||||
return | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,194 @@ | |||||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||||
// | |||||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. | |||||
// | |||||
// This Source Code Form is subject to the terms of the Mozilla Public | |||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||||
// You can obtain one at http://mozilla.org/MPL/2.0/. | |||||
package mysql | |||||
import ( | |||||
"database/sql" | |||||
"reflect" | |||||
) | |||||
func (mf *mysqlField) typeDatabaseName() string { | |||||
switch mf.fieldType { | |||||
case fieldTypeBit: | |||||
return "BIT" | |||||
case fieldTypeBLOB: | |||||
if mf.charSet != collations[binaryCollation] { | |||||
return "TEXT" | |||||
} | |||||
return "BLOB" | |||||
case fieldTypeDate: | |||||
return "DATE" | |||||
case fieldTypeDateTime: | |||||
return "DATETIME" | |||||
case fieldTypeDecimal: | |||||
return "DECIMAL" | |||||
case fieldTypeDouble: | |||||
return "DOUBLE" | |||||
case fieldTypeEnum: | |||||
return "ENUM" | |||||
case fieldTypeFloat: | |||||
return "FLOAT" | |||||
case fieldTypeGeometry: | |||||
return "GEOMETRY" | |||||
case fieldTypeInt24: | |||||
return "MEDIUMINT" | |||||
case fieldTypeJSON: | |||||
return "JSON" | |||||
case fieldTypeLong: | |||||
return "INT" | |||||
case fieldTypeLongBLOB: | |||||
if mf.charSet != collations[binaryCollation] { | |||||
return "LONGTEXT" | |||||
} | |||||
return "LONGBLOB" | |||||
case fieldTypeLongLong: | |||||
return "BIGINT" | |||||
case fieldTypeMediumBLOB: | |||||
if mf.charSet != collations[binaryCollation] { | |||||
return "MEDIUMTEXT" | |||||
} | |||||
return "MEDIUMBLOB" | |||||
case fieldTypeNewDate: | |||||
return "DATE" | |||||
case fieldTypeNewDecimal: | |||||
return "DECIMAL" | |||||
case fieldTypeNULL: | |||||
return "NULL" | |||||
case fieldTypeSet: | |||||
return "SET" | |||||
case fieldTypeShort: | |||||
return "SMALLINT" | |||||
case fieldTypeString: | |||||
if mf.charSet == collations[binaryCollation] { | |||||
return "BINARY" | |||||
} | |||||
return "CHAR" | |||||
case fieldTypeTime: | |||||
return "TIME" | |||||
case fieldTypeTimestamp: | |||||
return "TIMESTAMP" | |||||
case fieldTypeTiny: | |||||
return "TINYINT" | |||||
case fieldTypeTinyBLOB: | |||||
if mf.charSet != collations[binaryCollation] { | |||||
return "TINYTEXT" | |||||
} | |||||
return "TINYBLOB" | |||||
case fieldTypeVarChar: | |||||
if mf.charSet == collations[binaryCollation] { | |||||
return "VARBINARY" | |||||
} | |||||
return "VARCHAR" | |||||
case fieldTypeVarString: | |||||
if mf.charSet == collations[binaryCollation] { | |||||
return "VARBINARY" | |||||
} | |||||
return "VARCHAR" | |||||
case fieldTypeYear: | |||||
return "YEAR" | |||||
default: | |||||
return "" | |||||
} | |||||
} | |||||
var ( | |||||
scanTypeFloat32 = reflect.TypeOf(float32(0)) | |||||
scanTypeFloat64 = reflect.TypeOf(float64(0)) | |||||
scanTypeInt8 = reflect.TypeOf(int8(0)) | |||||
scanTypeInt16 = reflect.TypeOf(int16(0)) | |||||
scanTypeInt32 = reflect.TypeOf(int32(0)) | |||||
scanTypeInt64 = reflect.TypeOf(int64(0)) | |||||
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) | |||||
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) | |||||
scanTypeNullTime = reflect.TypeOf(NullTime{}) | |||||
scanTypeUint8 = reflect.TypeOf(uint8(0)) | |||||
scanTypeUint16 = reflect.TypeOf(uint16(0)) | |||||
scanTypeUint32 = reflect.TypeOf(uint32(0)) | |||||
scanTypeUint64 = reflect.TypeOf(uint64(0)) | |||||
scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) | |||||
scanTypeUnknown = reflect.TypeOf(new(interface{})) | |||||
) | |||||
type mysqlField struct { | |||||
tableName string | |||||
name string | |||||
length uint32 | |||||
flags fieldFlag | |||||
fieldType fieldType | |||||
decimals byte | |||||
charSet uint8 | |||||
} | |||||
func (mf *mysqlField) scanType() reflect.Type { | |||||
switch mf.fieldType { | |||||
case fieldTypeTiny: | |||||
if mf.flags&flagNotNULL != 0 { | |||||
if mf.flags&flagUnsigned != 0 { | |||||
return scanTypeUint8 | |||||
} | |||||
return scanTypeInt8 | |||||
} | |||||
return scanTypeNullInt | |||||
case fieldTypeShort, fieldTypeYear: | |||||
if mf.flags&flagNotNULL != 0 { | |||||
if mf.flags&flagUnsigned != 0 { | |||||
return scanTypeUint16 | |||||
} | |||||
return scanTypeInt16 | |||||
} | |||||
return scanTypeNullInt | |||||
case fieldTypeInt24, fieldTypeLong: | |||||
if mf.flags&flagNotNULL != 0 { | |||||
if mf.flags&flagUnsigned != 0 { | |||||
return scanTypeUint32 | |||||
} | |||||
return scanTypeInt32 | |||||
} | |||||
return scanTypeNullInt | |||||
case fieldTypeLongLong: | |||||
if mf.flags&flagNotNULL != 0 { | |||||
if mf.flags&flagUnsigned != 0 { | |||||
return scanTypeUint64 | |||||
} | |||||
return scanTypeInt64 | |||||
} | |||||
return scanTypeNullInt | |||||
case fieldTypeFloat: | |||||
if mf.flags&flagNotNULL != 0 { | |||||
return scanTypeFloat32 | |||||
} | |||||
return scanTypeNullFloat | |||||
case fieldTypeDouble: | |||||
if mf.flags&flagNotNULL != 0 { | |||||
return scanTypeFloat64 | |||||
} | |||||
return scanTypeNullFloat | |||||
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, | |||||
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, | |||||
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, | |||||
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, | |||||
fieldTypeTime: | |||||
return scanTypeRawBytes | |||||
case fieldTypeDate, fieldTypeNewDate, | |||||
fieldTypeTimestamp, fieldTypeDateTime: | |||||
// NullTime is always returned for more consistent behavior as it can | |||||
// handle both cases of parseTime regardless if the field is nullable. | |||||
return scanTypeNullTime | |||||
default: | |||||
return scanTypeUnknown | |||||
} | |||||
} |
@@ -147,7 +147,8 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { | |||||
} | } | ||||
// send content packets | // send content packets | ||||
if err == nil { | |||||
// if packetSize == 0, the Reader contains no data | |||||
if err == nil && packetSize > 0 { | |||||
data := make([]byte, 4+packetSize) | data := make([]byte, 4+packetSize) | ||||
var n int | var n int | ||||
for err == nil { | for err == nil { | ||||
@@ -173,8 +174,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { | |||||
// read OK packet | // read OK packet | ||||
if err == nil { | if err == nil { | ||||
_, err = mc.readResultOK() | |||||
return err | |||||
return mc.readResultOK() | |||||
} | } | ||||
mc.readPacket() | mc.readPacket() | ||||
@@ -25,26 +25,23 @@ import ( | |||||
// Read packet to buffer 'data' | // Read packet to buffer 'data' | ||||
func (mc *mysqlConn) readPacket() ([]byte, error) { | func (mc *mysqlConn) readPacket() ([]byte, error) { | ||||
var payload []byte | |||||
var prevData []byte | |||||
for { | for { | ||||
// Read packet header | |||||
// read packet header | |||||
data, err := mc.buf.readNext(4) | data, err := mc.buf.readNext(4) | ||||
if err != nil { | if err != nil { | ||||
if cerr := mc.canceled.Value(); cerr != nil { | |||||
return nil, cerr | |||||
} | |||||
errLog.Print(err) | errLog.Print(err) | ||||
mc.Close() | mc.Close() | ||||
return nil, driver.ErrBadConn | |||||
return nil, ErrInvalidConn | |||||
} | } | ||||
// Packet Length [24 bit] | |||||
// packet length [24 bit] | |||||
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) | pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) | ||||
if pktLen < 1 { | |||||
errLog.Print(ErrMalformPkt) | |||||
mc.Close() | |||||
return nil, driver.ErrBadConn | |||||
} | |||||
// Check Packet Sync [8 bit] | |||||
// check packet sync [8 bit] | |||||
if data[3] != mc.sequence { | if data[3] != mc.sequence { | ||||
if data[3] > mc.sequence { | if data[3] > mc.sequence { | ||||
return nil, ErrPktSyncMul | return nil, ErrPktSyncMul | ||||
@@ -53,26 +50,41 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { | |||||
} | } | ||||
mc.sequence++ | mc.sequence++ | ||||
// Read packet body [pktLen bytes] | |||||
// packets with length 0 terminate a previous packet which is a | |||||
// multiple of (2^24)−1 bytes long | |||||
if pktLen == 0 { | |||||
// there was no previous packet | |||||
if prevData == nil { | |||||
errLog.Print(ErrMalformPkt) | |||||
mc.Close() | |||||
return nil, ErrInvalidConn | |||||
} | |||||
return prevData, nil | |||||
} | |||||
// read packet body [pktLen bytes] | |||||
data, err = mc.buf.readNext(pktLen) | data, err = mc.buf.readNext(pktLen) | ||||
if err != nil { | if err != nil { | ||||
if cerr := mc.canceled.Value(); cerr != nil { | |||||
return nil, cerr | |||||
} | |||||
errLog.Print(err) | errLog.Print(err) | ||||
mc.Close() | mc.Close() | ||||
return nil, driver.ErrBadConn | |||||
return nil, ErrInvalidConn | |||||
} | } | ||||
isLastPacket := (pktLen < maxPacketSize) | |||||
// return data if this was the last packet | |||||
if pktLen < maxPacketSize { | |||||
// zero allocations for non-split packets | |||||
if prevData == nil { | |||||
return data, nil | |||||
} | |||||
// Zero allocations for non-splitting packets | |||||
if isLastPacket && payload == nil { | |||||
return data, nil | |||||
return append(prevData, data...), nil | |||||
} | } | ||||
payload = append(payload, data...) | |||||
if isLastPacket { | |||||
return payload, nil | |||||
} | |||||
prevData = append(prevData, data...) | |||||
} | } | ||||
} | } | ||||
@@ -119,33 +131,47 @@ func (mc *mysqlConn) writePacket(data []byte) error { | |||||
// Handle error | // Handle error | ||||
if err == nil { // n != len(data) | if err == nil { // n != len(data) | ||||
mc.cleanup() | |||||
errLog.Print(ErrMalformPkt) | errLog.Print(ErrMalformPkt) | ||||
} else { | } else { | ||||
if cerr := mc.canceled.Value(); cerr != nil { | |||||
return cerr | |||||
} | |||||
if n == 0 && pktLen == len(data)-4 { | |||||
// only for the first loop iteration when nothing was written yet | |||||
return errBadConnNoWrite | |||||
} | |||||
mc.cleanup() | |||||
errLog.Print(err) | errLog.Print(err) | ||||
} | } | ||||
return driver.ErrBadConn | |||||
return ErrInvalidConn | |||||
} | } | ||||
} | } | ||||
/****************************************************************************** | /****************************************************************************** | ||||
* Initialisation Process * | |||||
* Initialization Process * | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
// Handshake Initialization Packet | // Handshake Initialization Packet | ||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake | // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake | ||||
func (mc *mysqlConn) readInitPacket() ([]byte, error) { | |||||
func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { | |||||
data, err := mc.readPacket() | data, err := mc.readPacket() | ||||
if err != nil { | if err != nil { | ||||
return nil, err | |||||
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since | |||||
// in connection initialization we don't risk retrying non-idempotent actions. | |||||
if err == ErrInvalidConn { | |||||
return nil, "", driver.ErrBadConn | |||||
} | |||||
return nil, "", err | |||||
} | } | ||||
if data[0] == iERR { | if data[0] == iERR { | ||||
return nil, mc.handleErrorPacket(data) | |||||
return nil, "", mc.handleErrorPacket(data) | |||||
} | } | ||||
// protocol version [1 byte] | // protocol version [1 byte] | ||||
if data[0] < minProtocolVersion { | if data[0] < minProtocolVersion { | ||||
return nil, fmt.Errorf( | |||||
return nil, "", fmt.Errorf( | |||||
"unsupported protocol version %d. Version %d or higher is required", | "unsupported protocol version %d. Version %d or higher is required", | ||||
data[0], | data[0], | ||||
minProtocolVersion, | minProtocolVersion, | ||||
@@ -157,7 +183,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { | |||||
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 | pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 | ||||
// first part of the password cipher [8 bytes] | // first part of the password cipher [8 bytes] | ||||
cipher := data[pos : pos+8] | |||||
authData := data[pos : pos+8] | |||||
// (filler) always 0x00 [1 byte] | // (filler) always 0x00 [1 byte] | ||||
pos += 8 + 1 | pos += 8 + 1 | ||||
@@ -165,13 +191,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { | |||||
// capability flags (lower 2 bytes) [2 bytes] | // capability flags (lower 2 bytes) [2 bytes] | ||||
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) | mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) | ||||
if mc.flags&clientProtocol41 == 0 { | if mc.flags&clientProtocol41 == 0 { | ||||
return nil, ErrOldProtocol | |||||
return nil, "", ErrOldProtocol | |||||
} | } | ||||
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { | if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { | ||||
return nil, ErrNoTLS | |||||
return nil, "", ErrNoTLS | |||||
} | } | ||||
pos += 2 | pos += 2 | ||||
plugin := "" | |||||
if len(data) > pos { | if len(data) > pos { | ||||
// character set [1 byte] | // character set [1 byte] | ||||
// status flags [2 bytes] | // status flags [2 bytes] | ||||
@@ -192,32 +219,34 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { | |||||
// | // | ||||
// The official Python library uses the fixed length 12 | // The official Python library uses the fixed length 12 | ||||
// which seems to work but technically could have a hidden bug. | // which seems to work but technically could have a hidden bug. | ||||
cipher = append(cipher, data[pos:pos+12]...) | |||||
authData = append(authData, data[pos:pos+12]...) | |||||
pos += 13 | |||||
// TODO: Verify string termination | |||||
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) | // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) | ||||
// \NUL otherwise | // \NUL otherwise | ||||
// | |||||
//if data[len(data)-1] == 0 { | |||||
// return | |||||
//} | |||||
//return ErrMalformPkt | |||||
if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { | |||||
plugin = string(data[pos : pos+end]) | |||||
} else { | |||||
plugin = string(data[pos:]) | |||||
} | |||||
// make a memory safe copy of the cipher slice | // make a memory safe copy of the cipher slice | ||||
var b [20]byte | var b [20]byte | ||||
copy(b[:], cipher) | |||||
return b[:], nil | |||||
copy(b[:], authData) | |||||
return b[:], plugin, nil | |||||
} | } | ||||
plugin = defaultAuthPlugin | |||||
// make a memory safe copy of the cipher slice | // make a memory safe copy of the cipher slice | ||||
var b [8]byte | var b [8]byte | ||||
copy(b[:], cipher) | |||||
return b[:], nil | |||||
copy(b[:], authData) | |||||
return b[:], plugin, nil | |||||
} | } | ||||
// Client Authentication Packet | // Client Authentication Packet | ||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse | // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse | ||||
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||||
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error { | |||||
// Adjust client flags based on server support | // Adjust client flags based on server support | ||||
clientFlags := clientProtocol41 | | clientFlags := clientProtocol41 | | ||||
clientSecureConn | | clientSecureConn | | ||||
@@ -241,10 +270,19 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||||
clientFlags |= clientMultiStatements | clientFlags |= clientMultiStatements | ||||
} | } | ||||
// User Password | |||||
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) | |||||
// encode length of the auth plugin data | |||||
var authRespLEIBuf [9]byte | |||||
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) | |||||
if len(authRespLEI) > 1 { | |||||
// if the length can not be written in 1 byte, it must be written as a | |||||
// length encoded integer | |||||
clientFlags |= clientPluginAuthLenEncClientData | |||||
} | |||||
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 | |||||
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 | |||||
if addNUL { | |||||
pktLen++ | |||||
} | |||||
// To specify a db name | // To specify a db name | ||||
if n := len(mc.cfg.DBName); n > 0 { | if n := len(mc.cfg.DBName); n > 0 { | ||||
@@ -255,9 +293,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||||
// Calculate packet length and get buffer with that size | // Calculate packet length and get buffer with that size | ||||
data := mc.buf.takeSmallBuffer(pktLen + 4) | data := mc.buf.takeSmallBuffer(pktLen + 4) | ||||
if data == nil { | if data == nil { | ||||
// can not take the buffer. Something must be wrong with the connection | |||||
// cannot take the buffer. Something must be wrong with the connection | |||||
errLog.Print(ErrBusyBuffer) | errLog.Print(ErrBusyBuffer) | ||||
return driver.ErrBadConn | |||||
return errBadConnNoWrite | |||||
} | } | ||||
// ClientFlags [32 bit] | // ClientFlags [32 bit] | ||||
@@ -312,9 +350,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||||
data[pos] = 0x00 | data[pos] = 0x00 | ||||
pos++ | pos++ | ||||
// ScrambleBuffer [length encoded integer] | |||||
data[pos] = byte(len(scrambleBuff)) | |||||
pos += 1 + copy(data[pos+1:], scrambleBuff) | |||||
// Auth Data [length encoded integer] | |||||
pos += copy(data[pos:], authRespLEI) | |||||
pos += copy(data[pos:], authResp) | |||||
if addNUL { | |||||
data[pos] = 0x00 | |||||
pos++ | |||||
} | |||||
// Databasename [null terminated string] | // Databasename [null terminated string] | ||||
if len(mc.cfg.DBName) > 0 { | if len(mc.cfg.DBName) > 0 { | ||||
@@ -323,72 +365,32 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||||
pos++ | pos++ | ||||
} | } | ||||
// Assume native client during response | |||||
pos += copy(data[pos:], "mysql_native_password") | |||||
pos += copy(data[pos:], plugin) | |||||
data[pos] = 0x00 | data[pos] = 0x00 | ||||
// Send Auth packet | // Send Auth packet | ||||
return mc.writePacket(data) | return mc.writePacket(data) | ||||
} | } | ||||
// Client old authentication packet | |||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse | // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse | ||||
func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { | |||||
// User password | |||||
scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) | |||||
// Calculate the packet length and add a tailing 0 | |||||
pktLen := len(scrambleBuff) + 1 | |||||
data := mc.buf.takeSmallBuffer(4 + pktLen) | |||||
if data == nil { | |||||
// can not take the buffer. Something must be wrong with the connection | |||||
errLog.Print(ErrBusyBuffer) | |||||
return driver.ErrBadConn | |||||
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error { | |||||
pktLen := 4 + len(authData) | |||||
if addNUL { | |||||
pktLen++ | |||||
} | } | ||||
// Add the scrambled password [null terminated string] | |||||
copy(data[4:], scrambleBuff) | |||||
data[4+pktLen-1] = 0x00 | |||||
return mc.writePacket(data) | |||||
} | |||||
// Client clear text authentication packet | |||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse | |||||
func (mc *mysqlConn) writeClearAuthPacket() error { | |||||
// Calculate the packet length and add a tailing 0 | |||||
pktLen := len(mc.cfg.Passwd) + 1 | |||||
data := mc.buf.takeSmallBuffer(4 + pktLen) | |||||
data := mc.buf.takeSmallBuffer(pktLen) | |||||
if data == nil { | if data == nil { | ||||
// can not take the buffer. Something must be wrong with the connection | |||||
// cannot take the buffer. Something must be wrong with the connection | |||||
errLog.Print(ErrBusyBuffer) | errLog.Print(ErrBusyBuffer) | ||||
return driver.ErrBadConn | |||||
return errBadConnNoWrite | |||||
} | } | ||||
// Add the clear password [null terminated string] | |||||
copy(data[4:], mc.cfg.Passwd) | |||||
data[4+pktLen-1] = 0x00 | |||||
return mc.writePacket(data) | |||||
} | |||||
// Native password authentication method | |||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse | |||||
func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { | |||||
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) | |||||
// Calculate the packet length and add a tailing 0 | |||||
pktLen := len(scrambleBuff) | |||||
data := mc.buf.takeSmallBuffer(4 + pktLen) | |||||
if data == nil { | |||||
// can not take the buffer. Something must be wrong with the connection | |||||
errLog.Print(ErrBusyBuffer) | |||||
return driver.ErrBadConn | |||||
// Add the auth data [EOF] | |||||
copy(data[4:], authData) | |||||
if addNUL { | |||||
data[pktLen-1] = 0x00 | |||||
} | } | ||||
// Add the scramble | |||||
copy(data[4:], scrambleBuff) | |||||
return mc.writePacket(data) | return mc.writePacket(data) | ||||
} | } | ||||
@@ -402,9 +404,9 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { | |||||
data := mc.buf.takeSmallBuffer(4 + 1) | data := mc.buf.takeSmallBuffer(4 + 1) | ||||
if data == nil { | if data == nil { | ||||
// can not take the buffer. Something must be wrong with the connection | |||||
// cannot take the buffer. Something must be wrong with the connection | |||||
errLog.Print(ErrBusyBuffer) | errLog.Print(ErrBusyBuffer) | ||||
return driver.ErrBadConn | |||||
return errBadConnNoWrite | |||||
} | } | ||||
// Add command byte | // Add command byte | ||||
@@ -421,9 +423,9 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { | |||||
pktLen := 1 + len(arg) | pktLen := 1 + len(arg) | ||||
data := mc.buf.takeBuffer(pktLen + 4) | data := mc.buf.takeBuffer(pktLen + 4) | ||||
if data == nil { | if data == nil { | ||||
// can not take the buffer. Something must be wrong with the connection | |||||
// cannot take the buffer. Something must be wrong with the connection | |||||
errLog.Print(ErrBusyBuffer) | errLog.Print(ErrBusyBuffer) | ||||
return driver.ErrBadConn | |||||
return errBadConnNoWrite | |||||
} | } | ||||
// Add command byte | // Add command byte | ||||
@@ -442,9 +444,9 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { | |||||
data := mc.buf.takeSmallBuffer(4 + 1 + 4) | data := mc.buf.takeSmallBuffer(4 + 1 + 4) | ||||
if data == nil { | if data == nil { | ||||
// can not take the buffer. Something must be wrong with the connection | |||||
// cannot take the buffer. Something must be wrong with the connection | |||||
errLog.Print(ErrBusyBuffer) | errLog.Print(ErrBusyBuffer) | ||||
return driver.ErrBadConn | |||||
return errBadConnNoWrite | |||||
} | } | ||||
// Add command byte | // Add command byte | ||||
@@ -464,43 +466,50 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { | |||||
* Result Packets * | * Result Packets * | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
// Returns error if Packet is not an 'Result OK'-Packet | |||||
func (mc *mysqlConn) readResultOK() ([]byte, error) { | |||||
func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { | |||||
data, err := mc.readPacket() | data, err := mc.readPacket() | ||||
if err == nil { | |||||
// packet indicator | |||||
switch data[0] { | |||||
if err != nil { | |||||
return nil, "", err | |||||
} | |||||
case iOK: | |||||
return nil, mc.handleOkPacket(data) | |||||
// packet indicator | |||||
switch data[0] { | |||||
case iEOF: | |||||
if len(data) > 1 { | |||||
pluginEndIndex := bytes.IndexByte(data, 0x00) | |||||
plugin := string(data[1:pluginEndIndex]) | |||||
cipher := data[pluginEndIndex+1 : len(data)-1] | |||||
if plugin == "mysql_old_password" { | |||||
// using old_passwords | |||||
return cipher, ErrOldPassword | |||||
} else if plugin == "mysql_clear_password" { | |||||
// using clear text password | |||||
return cipher, ErrCleartextPassword | |||||
} else if plugin == "mysql_native_password" { | |||||
// using mysql default authentication method | |||||
return cipher, ErrNativePassword | |||||
} else { | |||||
return cipher, ErrUnknownPlugin | |||||
} | |||||
} else { | |||||
return nil, ErrOldPassword | |||||
} | |||||
case iOK: | |||||
return nil, "", mc.handleOkPacket(data) | |||||
default: // Error otherwise | |||||
return nil, mc.handleErrorPacket(data) | |||||
case iAuthMoreData: | |||||
return data[1:], "", err | |||||
case iEOF: | |||||
if len(data) < 1 { | |||||
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest | |||||
return nil, "mysql_old_password", nil | |||||
} | |||||
pluginEndIndex := bytes.IndexByte(data, 0x00) | |||||
if pluginEndIndex < 0 { | |||||
return nil, "", ErrMalformPkt | |||||
} | } | ||||
plugin := string(data[1:pluginEndIndex]) | |||||
authData := data[pluginEndIndex+1:] | |||||
return authData, plugin, nil | |||||
default: // Error otherwise | |||||
return nil, "", mc.handleErrorPacket(data) | |||||
} | } | ||||
return nil, err | |||||
} | |||||
// Returns error if Packet is not an 'Result OK'-Packet | |||||
func (mc *mysqlConn) readResultOK() error { | |||||
data, err := mc.readPacket() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
if data[0] == iOK { | |||||
return mc.handleOkPacket(data) | |||||
} | |||||
return mc.handleErrorPacket(data) | |||||
} | } | ||||
// Result Set Header Packet | // Result Set Header Packet | ||||
@@ -543,6 +552,22 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { | |||||
// Error Number [16 bit uint] | // Error Number [16 bit uint] | ||||
errno := binary.LittleEndian.Uint16(data[1:3]) | errno := binary.LittleEndian.Uint16(data[1:3]) | ||||
// 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION | |||||
// 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) | |||||
if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { | |||||
// Oops; we are connected to a read-only connection, and won't be able | |||||
// to issue any write statements. Since RejectReadOnly is configured, | |||||
// we throw away this connection hoping this one would have write | |||||
// permission. This is specifically for a possible race condition | |||||
// during failover (e.g. on AWS Aurora). See README.md for more. | |||||
// | |||||
// We explicitly close the connection before returning | |||||
// driver.ErrBadConn to ensure that `database/sql` purges this | |||||
// connection and initiates a new one for next statement next time. | |||||
mc.Close() | |||||
return driver.ErrBadConn | |||||
} | |||||
pos := 3 | pos := 3 | ||||
// SQL State [optional: # + 5bytes string] | // SQL State [optional: # + 5bytes string] | ||||
@@ -577,19 +602,12 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { | |||||
// server_status [2 bytes] | // server_status [2 bytes] | ||||
mc.status = readStatus(data[1+n+m : 1+n+m+2]) | mc.status = readStatus(data[1+n+m : 1+n+m+2]) | ||||
if err := mc.discardResults(); err != nil { | |||||
return err | |||||
if mc.status&statusMoreResultsExists != 0 { | |||||
return nil | |||||
} | } | ||||
// warning count [2 bytes] | // warning count [2 bytes] | ||||
if !mc.strict { | |||||
return nil | |||||
} | |||||
pos := 1 + n + m + 2 | |||||
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { | |||||
return mc.getWarnings() | |||||
} | |||||
return nil | return nil | ||||
} | } | ||||
@@ -661,14 +679,21 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
pos += n | |||||
// Filler [uint8] | // Filler [uint8] | ||||
pos++ | |||||
// Charset [charset, collation uint8] | // Charset [charset, collation uint8] | ||||
columns[i].charSet = data[pos] | |||||
pos += 2 | |||||
// Length [uint32] | // Length [uint32] | ||||
pos += n + 1 + 2 + 4 | |||||
columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) | |||||
pos += 4 | |||||
// Field type [uint8] | // Field type [uint8] | ||||
columns[i].fieldType = data[pos] | |||||
columns[i].fieldType = fieldType(data[pos]) | |||||
pos++ | pos++ | ||||
// Flags [uint16] | // Flags [uint16] | ||||
@@ -691,6 +716,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { | |||||
func (rows *textRows) readRow(dest []driver.Value) error { | func (rows *textRows) readRow(dest []driver.Value) error { | ||||
mc := rows.mc | mc := rows.mc | ||||
if rows.rs.done { | |||||
return io.EOF | |||||
} | |||||
data, err := mc.readPacket() | data, err := mc.readPacket() | ||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
@@ -700,10 +729,10 @@ func (rows *textRows) readRow(dest []driver.Value) error { | |||||
if data[0] == iEOF && len(data) == 5 { | if data[0] == iEOF && len(data) == 5 { | ||||
// server_status [2 bytes] | // server_status [2 bytes] | ||||
rows.mc.status = readStatus(data[3:]) | rows.mc.status = readStatus(data[3:]) | ||||
if err := rows.mc.discardResults(); err != nil { | |||||
return err | |||||
rows.rs.done = true | |||||
if !rows.HasNextResultSet() { | |||||
rows.mc = nil | |||||
} | } | ||||
rows.mc = nil | |||||
return io.EOF | return io.EOF | ||||
} | } | ||||
if data[0] == iERR { | if data[0] == iERR { | ||||
@@ -725,7 +754,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { | |||||
if !mc.parseTime { | if !mc.parseTime { | ||||
continue | continue | ||||
} else { | } else { | ||||
switch rows.columns[i].fieldType { | |||||
switch rows.rs.columns[i].fieldType { | |||||
case fieldTypeTimestamp, fieldTypeDateTime, | case fieldTypeTimestamp, fieldTypeDateTime, | ||||
fieldTypeDate, fieldTypeNewDate: | fieldTypeDate, fieldTypeNewDate: | ||||
dest[i], err = parseDateTime( | dest[i], err = parseDateTime( | ||||
@@ -797,14 +826,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { | |||||
// Reserved [8 bit] | // Reserved [8 bit] | ||||
// Warning count [16 bit uint] | // Warning count [16 bit uint] | ||||
if !stmt.mc.strict { | |||||
return columnCount, nil | |||||
} | |||||
// Check for warnings count > 0, only available in MySQL > 4.1 | |||||
if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { | |||||
return columnCount, stmt.mc.getWarnings() | |||||
} | |||||
return columnCount, nil | return columnCount, nil | ||||
} | } | ||||
return 0, err | return 0, err | ||||
@@ -821,7 +843,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { | |||||
// 2 bytes paramID | // 2 bytes paramID | ||||
const dataOffset = 1 + 4 + 2 | const dataOffset = 1 + 4 + 2 | ||||
// Can not use the write buffer since | |||||
// Cannot use the write buffer since | |||||
// a) the buffer is too small | // a) the buffer is too small | ||||
// b) it is in use | // b) it is in use | ||||
data := make([]byte, 4+1+4+2+len(arg)) | data := make([]byte, 4+1+4+2+len(arg)) | ||||
@@ -876,6 +898,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||||
const minPktLen = 4 + 1 + 4 + 1 + 4 | const minPktLen = 4 + 1 + 4 + 1 + 4 | ||||
mc := stmt.mc | mc := stmt.mc | ||||
// Determine threshould dynamically to avoid packet size shortage. | |||||
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) | |||||
if longDataSize < 64 { | |||||
longDataSize = 64 | |||||
} | |||||
// Reset packet-sequence | // Reset packet-sequence | ||||
mc.sequence = 0 | mc.sequence = 0 | ||||
@@ -887,9 +915,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||||
data = mc.buf.takeCompleteBuffer() | data = mc.buf.takeCompleteBuffer() | ||||
} | } | ||||
if data == nil { | if data == nil { | ||||
// can not take the buffer. Something must be wrong with the connection | |||||
// cannot take the buffer. Something must be wrong with the connection | |||||
errLog.Print(ErrBusyBuffer) | errLog.Print(ErrBusyBuffer) | ||||
return driver.ErrBadConn | |||||
return errBadConnNoWrite | |||||
} | } | ||||
// command [1 byte] | // command [1 byte] | ||||
@@ -948,7 +976,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||||
// build NULL-bitmap | // build NULL-bitmap | ||||
if arg == nil { | if arg == nil { | ||||
nullMask[i/8] |= 1 << (uint(i) & 7) | nullMask[i/8] |= 1 << (uint(i) & 7) | ||||
paramTypes[i+i] = fieldTypeNULL | |||||
paramTypes[i+i] = byte(fieldTypeNULL) | |||||
paramTypes[i+i+1] = 0x00 | paramTypes[i+i+1] = 0x00 | ||||
continue | continue | ||||
} | } | ||||
@@ -956,7 +984,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||||
// cache types and values | // cache types and values | ||||
switch v := arg.(type) { | switch v := arg.(type) { | ||||
case int64: | case int64: | ||||
paramTypes[i+i] = fieldTypeLongLong | |||||
paramTypes[i+i] = byte(fieldTypeLongLong) | |||||
paramTypes[i+i+1] = 0x00 | paramTypes[i+i+1] = 0x00 | ||||
if cap(paramValues)-len(paramValues)-8 >= 0 { | if cap(paramValues)-len(paramValues)-8 >= 0 { | ||||
@@ -972,7 +1000,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||||
} | } | ||||
case float64: | case float64: | ||||
paramTypes[i+i] = fieldTypeDouble | |||||
paramTypes[i+i] = byte(fieldTypeDouble) | |||||
paramTypes[i+i+1] = 0x00 | paramTypes[i+i+1] = 0x00 | ||||
if cap(paramValues)-len(paramValues)-8 >= 0 { | if cap(paramValues)-len(paramValues)-8 >= 0 { | ||||
@@ -988,7 +1016,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||||
} | } | ||||
case bool: | case bool: | ||||
paramTypes[i+i] = fieldTypeTiny | |||||
paramTypes[i+i] = byte(fieldTypeTiny) | |||||
paramTypes[i+i+1] = 0x00 | paramTypes[i+i+1] = 0x00 | ||||
if v { | if v { | ||||
@@ -1000,10 +1028,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||||
case []byte: | case []byte: | ||||
// Common case (non-nil value) first | // Common case (non-nil value) first | ||||
if v != nil { | if v != nil { | ||||
paramTypes[i+i] = fieldTypeString | |||||
paramTypes[i+i] = byte(fieldTypeString) | |||||
paramTypes[i+i+1] = 0x00 | paramTypes[i+i+1] = 0x00 | ||||
if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { | |||||
if len(v) < longDataSize { | |||||
paramValues = appendLengthEncodedInteger(paramValues, | paramValues = appendLengthEncodedInteger(paramValues, | ||||
uint64(len(v)), | uint64(len(v)), | ||||
) | ) | ||||
@@ -1018,14 +1046,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||||
// Handle []byte(nil) as a NULL value | // Handle []byte(nil) as a NULL value | ||||
nullMask[i/8] |= 1 << (uint(i) & 7) | nullMask[i/8] |= 1 << (uint(i) & 7) | ||||
paramTypes[i+i] = fieldTypeNULL | |||||
paramTypes[i+i] = byte(fieldTypeNULL) | |||||
paramTypes[i+i+1] = 0x00 | paramTypes[i+i+1] = 0x00 | ||||
case string: | case string: | ||||
paramTypes[i+i] = fieldTypeString | |||||
paramTypes[i+i] = byte(fieldTypeString) | |||||
paramTypes[i+i+1] = 0x00 | paramTypes[i+i+1] = 0x00 | ||||
if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { | |||||
if len(v) < longDataSize { | |||||
paramValues = appendLengthEncodedInteger(paramValues, | paramValues = appendLengthEncodedInteger(paramValues, | ||||
uint64(len(v)), | uint64(len(v)), | ||||
) | ) | ||||
@@ -1037,23 +1065,25 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||||
} | } | ||||
case time.Time: | case time.Time: | ||||
paramTypes[i+i] = fieldTypeString | |||||
paramTypes[i+i] = byte(fieldTypeString) | |||||
paramTypes[i+i+1] = 0x00 | paramTypes[i+i+1] = 0x00 | ||||
var val []byte | |||||
var a [64]byte | |||||
var b = a[:0] | |||||
if v.IsZero() { | if v.IsZero() { | ||||
val = []byte("0000-00-00") | |||||
b = append(b, "0000-00-00"...) | |||||
} else { | } else { | ||||
val = []byte(v.In(mc.cfg.Loc).Format(timeFormat)) | |||||
b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) | |||||
} | } | ||||
paramValues = appendLengthEncodedInteger(paramValues, | paramValues = appendLengthEncodedInteger(paramValues, | ||||
uint64(len(val)), | |||||
uint64(len(b)), | |||||
) | ) | ||||
paramValues = append(paramValues, val...) | |||||
paramValues = append(paramValues, b...) | |||||
default: | default: | ||||
return fmt.Errorf("can not convert type: %T", arg) | |||||
return fmt.Errorf("cannot convert type: %T", arg) | |||||
} | } | ||||
} | } | ||||
@@ -1086,8 +1116,6 @@ func (mc *mysqlConn) discardResults() error { | |||||
if err := mc.readUntilEOF(); err != nil { | if err := mc.readUntilEOF(); err != nil { | ||||
return err | return err | ||||
} | } | ||||
} else { | |||||
mc.status &^= statusMoreResultsExists | |||||
} | } | ||||
} | } | ||||
return nil | return nil | ||||
@@ -1105,16 +1133,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||||
// EOF Packet | // EOF Packet | ||||
if data[0] == iEOF && len(data) == 5 { | if data[0] == iEOF && len(data) == 5 { | ||||
rows.mc.status = readStatus(data[3:]) | rows.mc.status = readStatus(data[3:]) | ||||
if err := rows.mc.discardResults(); err != nil { | |||||
return err | |||||
rows.rs.done = true | |||||
if !rows.HasNextResultSet() { | |||||
rows.mc = nil | |||||
} | } | ||||
rows.mc = nil | |||||
return io.EOF | return io.EOF | ||||
} | } | ||||
mc := rows.mc | |||||
rows.mc = nil | rows.mc = nil | ||||
// Error otherwise | // Error otherwise | ||||
return rows.mc.handleErrorPacket(data) | |||||
return mc.handleErrorPacket(data) | |||||
} | } | ||||
// NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] | // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] | ||||
@@ -1130,14 +1159,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||||
} | } | ||||
// Convert to byte-coded string | // Convert to byte-coded string | ||||
switch rows.columns[i].fieldType { | |||||
switch rows.rs.columns[i].fieldType { | |||||
case fieldTypeNULL: | case fieldTypeNULL: | ||||
dest[i] = nil | dest[i] = nil | ||||
continue | continue | ||||
// Numeric Types | // Numeric Types | ||||
case fieldTypeTiny: | case fieldTypeTiny: | ||||
if rows.columns[i].flags&flagUnsigned != 0 { | |||||
if rows.rs.columns[i].flags&flagUnsigned != 0 { | |||||
dest[i] = int64(data[pos]) | dest[i] = int64(data[pos]) | ||||
} else { | } else { | ||||
dest[i] = int64(int8(data[pos])) | dest[i] = int64(int8(data[pos])) | ||||
@@ -1146,7 +1175,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||||
continue | continue | ||||
case fieldTypeShort, fieldTypeYear: | case fieldTypeShort, fieldTypeYear: | ||||
if rows.columns[i].flags&flagUnsigned != 0 { | |||||
if rows.rs.columns[i].flags&flagUnsigned != 0 { | |||||
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) | dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) | ||||
} else { | } else { | ||||
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) | dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) | ||||
@@ -1155,7 +1184,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||||
continue | continue | ||||
case fieldTypeInt24, fieldTypeLong: | case fieldTypeInt24, fieldTypeLong: | ||||
if rows.columns[i].flags&flagUnsigned != 0 { | |||||
if rows.rs.columns[i].flags&flagUnsigned != 0 { | |||||
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) | dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) | ||||
} else { | } else { | ||||
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) | dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) | ||||
@@ -1164,7 +1193,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||||
continue | continue | ||||
case fieldTypeLongLong: | case fieldTypeLongLong: | ||||
if rows.columns[i].flags&flagUnsigned != 0 { | |||||
if rows.rs.columns[i].flags&flagUnsigned != 0 { | |||||
val := binary.LittleEndian.Uint64(data[pos : pos+8]) | val := binary.LittleEndian.Uint64(data[pos : pos+8]) | ||||
if val > math.MaxInt64 { | if val > math.MaxInt64 { | ||||
dest[i] = uint64ToString(val) | dest[i] = uint64ToString(val) | ||||
@@ -1178,7 +1207,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||||
continue | continue | ||||
case fieldTypeFloat: | case fieldTypeFloat: | ||||
dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) | |||||
dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) | |||||
pos += 4 | pos += 4 | ||||
continue | continue | ||||
@@ -1218,10 +1247,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||||
case isNull: | case isNull: | ||||
dest[i] = nil | dest[i] = nil | ||||
continue | continue | ||||
case rows.columns[i].fieldType == fieldTypeTime: | |||||
case rows.rs.columns[i].fieldType == fieldTypeTime: | |||||
// database/sql does not support an equivalent to TIME, return a string | // database/sql does not support an equivalent to TIME, return a string | ||||
var dstlen uint8 | var dstlen uint8 | ||||
switch decimals := rows.columns[i].decimals; decimals { | |||||
switch decimals := rows.rs.columns[i].decimals; decimals { | |||||
case 0x00, 0x1f: | case 0x00, 0x1f: | ||||
dstlen = 8 | dstlen = 8 | ||||
case 1, 2, 3, 4, 5, 6: | case 1, 2, 3, 4, 5, 6: | ||||
@@ -1229,7 +1258,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||||
default: | default: | ||||
return fmt.Errorf( | return fmt.Errorf( | ||||
"protocol error, illegal decimals value %d", | "protocol error, illegal decimals value %d", | ||||
rows.columns[i].decimals, | |||||
rows.rs.columns[i].decimals, | |||||
) | ) | ||||
} | } | ||||
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) | dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) | ||||
@@ -1237,10 +1266,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||||
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) | dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) | ||||
default: | default: | ||||
var dstlen uint8 | var dstlen uint8 | ||||
if rows.columns[i].fieldType == fieldTypeDate { | |||||
if rows.rs.columns[i].fieldType == fieldTypeDate { | |||||
dstlen = 10 | dstlen = 10 | ||||
} else { | } else { | ||||
switch decimals := rows.columns[i].decimals; decimals { | |||||
switch decimals := rows.rs.columns[i].decimals; decimals { | |||||
case 0x00, 0x1f: | case 0x00, 0x1f: | ||||
dstlen = 19 | dstlen = 19 | ||||
case 1, 2, 3, 4, 5, 6: | case 1, 2, 3, 4, 5, 6: | ||||
@@ -1248,7 +1277,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||||
default: | default: | ||||
return fmt.Errorf( | return fmt.Errorf( | ||||
"protocol error, illegal decimals value %d", | "protocol error, illegal decimals value %d", | ||||
rows.columns[i].decimals, | |||||
rows.rs.columns[i].decimals, | |||||
) | ) | ||||
} | } | ||||
} | } | ||||
@@ -1264,7 +1293,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||||
// Please report if this happens! | // Please report if this happens! | ||||
default: | default: | ||||
return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) | |||||
return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) | |||||
} | } | ||||
} | } | ||||
@@ -11,19 +11,20 @@ package mysql | |||||
import ( | import ( | ||||
"database/sql/driver" | "database/sql/driver" | ||||
"io" | "io" | ||||
"math" | |||||
"reflect" | |||||
) | ) | ||||
type mysqlField struct { | |||||
tableName string | |||||
name string | |||||
flags fieldFlag | |||||
fieldType byte | |||||
decimals byte | |||||
type resultSet struct { | |||||
columns []mysqlField | |||||
columnNames []string | |||||
done bool | |||||
} | } | ||||
type mysqlRows struct { | type mysqlRows struct { | ||||
mc *mysqlConn | |||||
columns []mysqlField | |||||
mc *mysqlConn | |||||
rs resultSet | |||||
finish func() | |||||
} | } | ||||
type binaryRows struct { | type binaryRows struct { | ||||
@@ -34,37 +35,86 @@ type textRows struct { | |||||
mysqlRows | mysqlRows | ||||
} | } | ||||
type emptyRows struct{} | |||||
func (rows *mysqlRows) Columns() []string { | func (rows *mysqlRows) Columns() []string { | ||||
columns := make([]string, len(rows.columns)) | |||||
if rows.rs.columnNames != nil { | |||||
return rows.rs.columnNames | |||||
} | |||||
columns := make([]string, len(rows.rs.columns)) | |||||
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { | if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { | ||||
for i := range columns { | for i := range columns { | ||||
if tableName := rows.columns[i].tableName; len(tableName) > 0 { | |||||
columns[i] = tableName + "." + rows.columns[i].name | |||||
if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 { | |||||
columns[i] = tableName + "." + rows.rs.columns[i].name | |||||
} else { | } else { | ||||
columns[i] = rows.columns[i].name | |||||
columns[i] = rows.rs.columns[i].name | |||||
} | } | ||||
} | } | ||||
} else { | } else { | ||||
for i := range columns { | for i := range columns { | ||||
columns[i] = rows.columns[i].name | |||||
columns[i] = rows.rs.columns[i].name | |||||
} | } | ||||
} | } | ||||
rows.rs.columnNames = columns | |||||
return columns | return columns | ||||
} | } | ||||
func (rows *mysqlRows) Close() error { | |||||
func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { | |||||
return rows.rs.columns[i].typeDatabaseName() | |||||
} | |||||
// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { | |||||
// return int64(rows.rs.columns[i].length), true | |||||
// } | |||||
func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { | |||||
return rows.rs.columns[i].flags&flagNotNULL == 0, true | |||||
} | |||||
func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { | |||||
column := rows.rs.columns[i] | |||||
decimals := int64(column.decimals) | |||||
switch column.fieldType { | |||||
case fieldTypeDecimal, fieldTypeNewDecimal: | |||||
if decimals > 0 { | |||||
return int64(column.length) - 2, decimals, true | |||||
} | |||||
return int64(column.length) - 1, decimals, true | |||||
case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime: | |||||
return decimals, decimals, true | |||||
case fieldTypeFloat, fieldTypeDouble: | |||||
if decimals == 0x1f { | |||||
return math.MaxInt64, math.MaxInt64, true | |||||
} | |||||
return math.MaxInt64, decimals, true | |||||
} | |||||
return 0, 0, false | |||||
} | |||||
func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { | |||||
return rows.rs.columns[i].scanType() | |||||
} | |||||
func (rows *mysqlRows) Close() (err error) { | |||||
if f := rows.finish; f != nil { | |||||
f() | |||||
rows.finish = nil | |||||
} | |||||
mc := rows.mc | mc := rows.mc | ||||
if mc == nil { | if mc == nil { | ||||
return nil | return nil | ||||
} | } | ||||
if mc.netConn == nil { | |||||
return ErrInvalidConn | |||||
if err := mc.error(); err != nil { | |||||
return err | |||||
} | } | ||||
// Remove unread packets from stream | // Remove unread packets from stream | ||||
err := mc.readUntilEOF() | |||||
if !rows.rs.done { | |||||
err = mc.readUntilEOF() | |||||
} | |||||
if err == nil { | if err == nil { | ||||
if err = mc.discardResults(); err != nil { | if err = mc.discardResults(); err != nil { | ||||
return err | return err | ||||
@@ -75,22 +125,66 @@ func (rows *mysqlRows) Close() error { | |||||
return err | return err | ||||
} | } | ||||
func (rows *binaryRows) Next(dest []driver.Value) error { | |||||
if mc := rows.mc; mc != nil { | |||||
if mc.netConn == nil { | |||||
return ErrInvalidConn | |||||
func (rows *mysqlRows) HasNextResultSet() (b bool) { | |||||
if rows.mc == nil { | |||||
return false | |||||
} | |||||
return rows.mc.status&statusMoreResultsExists != 0 | |||||
} | |||||
func (rows *mysqlRows) nextResultSet() (int, error) { | |||||
if rows.mc == nil { | |||||
return 0, io.EOF | |||||
} | |||||
if err := rows.mc.error(); err != nil { | |||||
return 0, err | |||||
} | |||||
// Remove unread packets from stream | |||||
if !rows.rs.done { | |||||
if err := rows.mc.readUntilEOF(); err != nil { | |||||
return 0, err | |||||
} | } | ||||
rows.rs.done = true | |||||
} | |||||
// Fetch next row from stream | |||||
return rows.readRow(dest) | |||||
if !rows.HasNextResultSet() { | |||||
rows.mc = nil | |||||
return 0, io.EOF | |||||
} | } | ||||
return io.EOF | |||||
rows.rs = resultSet{} | |||||
return rows.mc.readResultSetHeaderPacket() | |||||
} | } | ||||
func (rows *textRows) Next(dest []driver.Value) error { | |||||
func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { | |||||
for { | |||||
resLen, err := rows.nextResultSet() | |||||
if err != nil { | |||||
return 0, err | |||||
} | |||||
if resLen > 0 { | |||||
return resLen, nil | |||||
} | |||||
rows.rs.done = true | |||||
} | |||||
} | |||||
func (rows *binaryRows) NextResultSet() error { | |||||
resLen, err := rows.nextNotEmptyResultSet() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
rows.rs.columns, err = rows.mc.readColumns(resLen) | |||||
return err | |||||
} | |||||
func (rows *binaryRows) Next(dest []driver.Value) error { | |||||
if mc := rows.mc; mc != nil { | if mc := rows.mc; mc != nil { | ||||
if mc.netConn == nil { | |||||
return ErrInvalidConn | |||||
if err := mc.error(); err != nil { | |||||
return err | |||||
} | } | ||||
// Fetch next row from stream | // Fetch next row from stream | ||||
@@ -99,14 +193,24 @@ func (rows *textRows) Next(dest []driver.Value) error { | |||||
return io.EOF | return io.EOF | ||||
} | } | ||||
func (rows emptyRows) Columns() []string { | |||||
return nil | |||||
} | |||||
func (rows *textRows) NextResultSet() (err error) { | |||||
resLen, err := rows.nextNotEmptyResultSet() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
func (rows emptyRows) Close() error { | |||||
return nil | |||||
rows.rs.columns, err = rows.mc.readColumns(resLen) | |||||
return err | |||||
} | } | ||||
func (rows emptyRows) Next(dest []driver.Value) error { | |||||
func (rows *textRows) Next(dest []driver.Value) error { | |||||
if mc := rows.mc; mc != nil { | |||||
if err := mc.error(); err != nil { | |||||
return err | |||||
} | |||||
// Fetch next row from stream | |||||
return rows.readRow(dest) | |||||
} | |||||
return io.EOF | return io.EOF | ||||
} | } |
@@ -11,6 +11,7 @@ package mysql | |||||
import ( | import ( | ||||
"database/sql/driver" | "database/sql/driver" | ||||
"fmt" | "fmt" | ||||
"io" | |||||
"reflect" | "reflect" | ||||
"strconv" | "strconv" | ||||
) | ) | ||||
@@ -19,12 +20,14 @@ type mysqlStmt struct { | |||||
mc *mysqlConn | mc *mysqlConn | ||||
id uint32 | id uint32 | ||||
paramCount int | paramCount int | ||||
columns []mysqlField // cached from the first query | |||||
} | } | ||||
func (stmt *mysqlStmt) Close() error { | func (stmt *mysqlStmt) Close() error { | ||||
if stmt.mc == nil || stmt.mc.netConn == nil { | |||||
errLog.Print(ErrInvalidConn) | |||||
if stmt.mc == nil || stmt.mc.closed.IsSet() { | |||||
// driver.Stmt.Close can be called more than once, thus this function | |||||
// has to be idempotent. | |||||
// See also Issue #450 and golang/go#16019. | |||||
//errLog.Print(ErrInvalidConn) | |||||
return driver.ErrBadConn | return driver.ErrBadConn | ||||
} | } | ||||
@@ -42,14 +45,14 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { | |||||
} | } | ||||
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { | func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { | ||||
if stmt.mc.netConn == nil { | |||||
if stmt.mc.closed.IsSet() { | |||||
errLog.Print(ErrInvalidConn) | errLog.Print(ErrInvalidConn) | ||||
return nil, driver.ErrBadConn | return nil, driver.ErrBadConn | ||||
} | } | ||||
// Send command | // Send command | ||||
err := stmt.writeExecutePacket(args) | err := stmt.writeExecutePacket(args) | ||||
if err != nil { | if err != nil { | ||||
return nil, err | |||||
return nil, stmt.mc.markBadConn(err) | |||||
} | } | ||||
mc := stmt.mc | mc := stmt.mc | ||||
@@ -59,37 +62,45 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { | |||||
// Read Result | // Read Result | ||||
resLen, err := mc.readResultSetHeaderPacket() | resLen, err := mc.readResultSetHeaderPacket() | ||||
if err == nil { | |||||
if resLen > 0 { | |||||
// Columns | |||||
err = mc.readUntilEOF() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
// Rows | |||||
err = mc.readUntilEOF() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if resLen > 0 { | |||||
// Columns | |||||
if err = mc.readUntilEOF(); err != nil { | |||||
return nil, err | |||||
} | } | ||||
if err == nil { | |||||
return &mysqlResult{ | |||||
affectedRows: int64(mc.affectedRows), | |||||
insertId: int64(mc.insertId), | |||||
}, nil | |||||
// Rows | |||||
if err := mc.readUntilEOF(); err != nil { | |||||
return nil, err | |||||
} | } | ||||
} | } | ||||
return nil, err | |||||
if err := mc.discardResults(); err != nil { | |||||
return nil, err | |||||
} | |||||
return &mysqlResult{ | |||||
affectedRows: int64(mc.affectedRows), | |||||
insertId: int64(mc.insertId), | |||||
}, nil | |||||
} | } | ||||
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | ||||
if stmt.mc.netConn == nil { | |||||
return stmt.query(args) | |||||
} | |||||
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { | |||||
if stmt.mc.closed.IsSet() { | |||||
errLog.Print(ErrInvalidConn) | errLog.Print(ErrInvalidConn) | ||||
return nil, driver.ErrBadConn | return nil, driver.ErrBadConn | ||||
} | } | ||||
// Send command | // Send command | ||||
err := stmt.writeExecutePacket(args) | err := stmt.writeExecutePacket(args) | ||||
if err != nil { | if err != nil { | ||||
return nil, err | |||||
return nil, stmt.mc.markBadConn(err) | |||||
} | } | ||||
mc := stmt.mc | mc := stmt.mc | ||||
@@ -104,14 +115,15 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | |||||
if resLen > 0 { | if resLen > 0 { | ||||
rows.mc = mc | rows.mc = mc | ||||
// Columns | |||||
// If not cached, read them and cache them | |||||
if stmt.columns == nil { | |||||
rows.columns, err = mc.readColumns(resLen) | |||||
stmt.columns = rows.columns | |||||
} else { | |||||
rows.columns = stmt.columns | |||||
err = mc.readUntilEOF() | |||||
rows.rs.columns, err = mc.readColumns(resLen) | |||||
} else { | |||||
rows.rs.done = true | |||||
switch err := rows.NextResultSet(); err { | |||||
case nil, io.EOF: | |||||
return rows, nil | |||||
default: | |||||
return nil, err | |||||
} | } | ||||
} | } | ||||
@@ -120,19 +132,36 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | |||||
type converter struct{} | type converter struct{} | ||||
// ConvertValue mirrors the reference/default converter in database/sql/driver | |||||
// with _one_ exception. We support uint64 with their high bit and the default | |||||
// implementation does not. This function should be kept in sync with | |||||
// database/sql/driver defaultConverter.ConvertValue() except for that | |||||
// deliberate difference. | |||||
func (c converter) ConvertValue(v interface{}) (driver.Value, error) { | func (c converter) ConvertValue(v interface{}) (driver.Value, error) { | ||||
if driver.IsValue(v) { | if driver.IsValue(v) { | ||||
return v, nil | return v, nil | ||||
} | } | ||||
if vr, ok := v.(driver.Valuer); ok { | |||||
sv, err := callValuerValue(vr) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if !driver.IsValue(sv) { | |||||
return nil, fmt.Errorf("non-Value type %T returned from Value", sv) | |||||
} | |||||
return sv, nil | |||||
} | |||||
rv := reflect.ValueOf(v) | rv := reflect.ValueOf(v) | ||||
switch rv.Kind() { | switch rv.Kind() { | ||||
case reflect.Ptr: | case reflect.Ptr: | ||||
// indirect pointers | // indirect pointers | ||||
if rv.IsNil() { | if rv.IsNil() { | ||||
return nil, nil | return nil, nil | ||||
} else { | |||||
return c.ConvertValue(rv.Elem().Interface()) | |||||
} | } | ||||
return c.ConvertValue(rv.Elem().Interface()) | |||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
return rv.Int(), nil | return rv.Int(), nil | ||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: | ||||
@@ -145,6 +174,38 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { | |||||
return int64(u64), nil | return int64(u64), nil | ||||
case reflect.Float32, reflect.Float64: | case reflect.Float32, reflect.Float64: | ||||
return rv.Float(), nil | return rv.Float(), nil | ||||
case reflect.Bool: | |||||
return rv.Bool(), nil | |||||
case reflect.Slice: | |||||
ek := rv.Type().Elem().Kind() | |||||
if ek == reflect.Uint8 { | |||||
return rv.Bytes(), nil | |||||
} | |||||
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) | |||||
case reflect.String: | |||||
return rv.String(), nil | |||||
} | } | ||||
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) | return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) | ||||
} | } | ||||
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() | |||||
// callValuerValue returns vr.Value(), with one exception: | |||||
// If vr.Value is an auto-generated method on a pointer type and the | |||||
// pointer is nil, it would panic at runtime in the panicwrap | |||||
// method. Treat it like nil instead. | |||||
// | |||||
// This is so people can implement driver.Value on value types and | |||||
// still use nil pointers to those types to mean nil/NULL, just like | |||||
// string/*string. | |||||
// | |||||
// This is an exact copy of the same-named unexported function from the | |||||
// database/sql package. | |||||
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { | |||||
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && | |||||
rv.IsNil() && | |||||
rv.Type().Elem().Implements(valuerReflectType) { | |||||
return nil, nil | |||||
} | |||||
return vr.Value() | |||||
} |
@@ -13,7 +13,7 @@ type mysqlTx struct { | |||||
} | } | ||||
func (tx *mysqlTx) Commit() (err error) { | func (tx *mysqlTx) Commit() (err error) { | ||||
if tx.mc == nil || tx.mc.netConn == nil { | |||||
if tx.mc == nil || tx.mc.closed.IsSet() { | |||||
return ErrInvalidConn | return ErrInvalidConn | ||||
} | } | ||||
err = tx.mc.exec("COMMIT") | err = tx.mc.exec("COMMIT") | ||||
@@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { | |||||
} | } | ||||
func (tx *mysqlTx) Rollback() (err error) { | func (tx *mysqlTx) Rollback() (err error) { | ||||
if tx.mc == nil || tx.mc.netConn == nil { | |||||
if tx.mc == nil || tx.mc.closed.IsSet() { | |||||
return ErrInvalidConn | return ErrInvalidConn | ||||
} | } | ||||
err = tx.mc.exec("ROLLBACK") | err = tx.mc.exec("ROLLBACK") | ||||
@@ -9,23 +9,29 @@ | |||||
package mysql | package mysql | ||||
import ( | import ( | ||||
"crypto/sha1" | |||||
"crypto/tls" | "crypto/tls" | ||||
"database/sql/driver" | "database/sql/driver" | ||||
"encoding/binary" | "encoding/binary" | ||||
"fmt" | "fmt" | ||||
"io" | "io" | ||||
"strings" | "strings" | ||||
"sync" | |||||
"sync/atomic" | |||||
"time" | "time" | ||||
) | ) | ||||
// Registry for custom tls.Configs | |||||
var ( | var ( | ||||
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs | |||||
tlsConfigLock sync.RWMutex | |||||
tlsConfigRegistry map[string]*tls.Config | |||||
) | ) | ||||
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. | // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. | ||||
// Use the key as a value in the DSN where tls=value. | // Use the key as a value in the DSN where tls=value. | ||||
// | // | ||||
// Note: The provided tls.Config is exclusively owned by the driver after | |||||
// registering it. | |||||
// | |||||
// rootCertPool := x509.NewCertPool() | // rootCertPool := x509.NewCertPool() | ||||
// pem, err := ioutil.ReadFile("/path/ca-cert.pem") | // pem, err := ioutil.ReadFile("/path/ca-cert.pem") | ||||
// if err != nil { | // if err != nil { | ||||
@@ -51,19 +57,32 @@ func RegisterTLSConfig(key string, config *tls.Config) error { | |||||
return fmt.Errorf("key '%s' is reserved", key) | return fmt.Errorf("key '%s' is reserved", key) | ||||
} | } | ||||
if tlsConfigRegister == nil { | |||||
tlsConfigRegister = make(map[string]*tls.Config) | |||||
tlsConfigLock.Lock() | |||||
if tlsConfigRegistry == nil { | |||||
tlsConfigRegistry = make(map[string]*tls.Config) | |||||
} | } | ||||
tlsConfigRegister[key] = config | |||||
tlsConfigRegistry[key] = config | |||||
tlsConfigLock.Unlock() | |||||
return nil | return nil | ||||
} | } | ||||
// DeregisterTLSConfig removes the tls.Config associated with key. | // DeregisterTLSConfig removes the tls.Config associated with key. | ||||
func DeregisterTLSConfig(key string) { | func DeregisterTLSConfig(key string) { | ||||
if tlsConfigRegister != nil { | |||||
delete(tlsConfigRegister, key) | |||||
tlsConfigLock.Lock() | |||||
if tlsConfigRegistry != nil { | |||||
delete(tlsConfigRegistry, key) | |||||
} | } | ||||
tlsConfigLock.Unlock() | |||||
} | |||||
func getTLSConfigClone(key string) (config *tls.Config) { | |||||
tlsConfigLock.RLock() | |||||
if v, ok := tlsConfigRegistry[key]; ok { | |||||
config = cloneTLSConfig(v) | |||||
} | |||||
tlsConfigLock.RUnlock() | |||||
return | |||||
} | } | ||||
// Returns the bool value of the input. | // Returns the bool value of the input. | ||||
@@ -81,119 +100,6 @@ func readBool(input string) (value bool, valid bool) { | |||||
} | } | ||||
/****************************************************************************** | /****************************************************************************** | ||||
* Authentication * | |||||
******************************************************************************/ | |||||
// Encrypt password using 4.1+ method | |||||
func scramblePassword(scramble, password []byte) []byte { | |||||
if len(password) == 0 { | |||||
return nil | |||||
} | |||||
// stage1Hash = SHA1(password) | |||||
crypt := sha1.New() | |||||
crypt.Write(password) | |||||
stage1 := crypt.Sum(nil) | |||||
// scrambleHash = SHA1(scramble + SHA1(stage1Hash)) | |||||
// inner Hash | |||||
crypt.Reset() | |||||
crypt.Write(stage1) | |||||
hash := crypt.Sum(nil) | |||||
// outer Hash | |||||
crypt.Reset() | |||||
crypt.Write(scramble) | |||||
crypt.Write(hash) | |||||
scramble = crypt.Sum(nil) | |||||
// token = scrambleHash XOR stage1Hash | |||||
for i := range scramble { | |||||
scramble[i] ^= stage1[i] | |||||
} | |||||
return scramble | |||||
} | |||||
// Encrypt password using pre 4.1 (old password) method | |||||
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c | |||||
type myRnd struct { | |||||
seed1, seed2 uint32 | |||||
} | |||||
const myRndMaxVal = 0x3FFFFFFF | |||||
// Pseudo random number generator | |||||
func newMyRnd(seed1, seed2 uint32) *myRnd { | |||||
return &myRnd{ | |||||
seed1: seed1 % myRndMaxVal, | |||||
seed2: seed2 % myRndMaxVal, | |||||
} | |||||
} | |||||
// Tested to be equivalent to MariaDB's floating point variant | |||||
// http://play.golang.org/p/QHvhd4qved | |||||
// http://play.golang.org/p/RG0q4ElWDx | |||||
func (r *myRnd) NextByte() byte { | |||||
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal | |||||
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal | |||||
return byte(uint64(r.seed1) * 31 / myRndMaxVal) | |||||
} | |||||
// Generate binary hash from byte string using insecure pre 4.1 method | |||||
func pwHash(password []byte) (result [2]uint32) { | |||||
var add uint32 = 7 | |||||
var tmp uint32 | |||||
result[0] = 1345345333 | |||||
result[1] = 0x12345671 | |||||
for _, c := range password { | |||||
// skip spaces and tabs in password | |||||
if c == ' ' || c == '\t' { | |||||
continue | |||||
} | |||||
tmp = uint32(c) | |||||
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) | |||||
result[1] += (result[1] << 8) ^ result[0] | |||||
add += tmp | |||||
} | |||||
// Remove sign bit (1<<31)-1) | |||||
result[0] &= 0x7FFFFFFF | |||||
result[1] &= 0x7FFFFFFF | |||||
return | |||||
} | |||||
// Encrypt password using insecure pre 4.1 method | |||||
func scrambleOldPassword(scramble, password []byte) []byte { | |||||
if len(password) == 0 { | |||||
return nil | |||||
} | |||||
scramble = scramble[:8] | |||||
hashPw := pwHash(password) | |||||
hashSc := pwHash(scramble) | |||||
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) | |||||
var out [8]byte | |||||
for i := range out { | |||||
out[i] = r.NextByte() + 64 | |||||
} | |||||
mask := r.NextByte() | |||||
for i := range out { | |||||
out[i] ^= mask | |||||
} | |||||
return out[:] | |||||
} | |||||
/****************************************************************************** | |||||
* Time related utils * | * Time related utils * | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
@@ -519,7 +425,7 @@ func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { | |||||
// Check data length | // Check data length | ||||
if len(b) >= n { | if len(b) >= n { | ||||
return b[n-int(num) : n], false, n, nil | |||||
return b[n-int(num) : n : n], false, n, nil | |||||
} | } | ||||
return nil, false, n, io.EOF | return nil, false, n, io.EOF | ||||
} | } | ||||
@@ -548,8 +454,8 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { | |||||
if len(b) == 0 { | if len(b) == 0 { | ||||
return 0, true, 1 | return 0, true, 1 | ||||
} | } | ||||
switch b[0] { | |||||
switch b[0] { | |||||
// 251: NULL | // 251: NULL | ||||
case 0xfb: | case 0xfb: | ||||
return 0, true, 1 | return 0, true, 1 | ||||
@@ -738,3 +644,67 @@ func escapeStringQuotes(buf []byte, v string) []byte { | |||||
return buf[:pos] | return buf[:pos] | ||||
} | } | ||||
/****************************************************************************** | |||||
* Sync utils * | |||||
******************************************************************************/ | |||||
// noCopy may be embedded into structs which must not be copied | |||||
// after the first use. | |||||
// | |||||
// See https://github.com/golang/go/issues/8005#issuecomment-190753527 | |||||
// for details. | |||||
type noCopy struct{} | |||||
// Lock is a no-op used by -copylocks checker from `go vet`. | |||||
func (*noCopy) Lock() {} | |||||
// atomicBool is a wrapper around uint32 for usage as a boolean value with | |||||
// atomic access. | |||||
type atomicBool struct { | |||||
_noCopy noCopy | |||||
value uint32 | |||||
} | |||||
// IsSet returns wether the current boolean value is true | |||||
func (ab *atomicBool) IsSet() bool { | |||||
return atomic.LoadUint32(&ab.value) > 0 | |||||
} | |||||
// Set sets the value of the bool regardless of the previous value | |||||
func (ab *atomicBool) Set(value bool) { | |||||
if value { | |||||
atomic.StoreUint32(&ab.value, 1) | |||||
} else { | |||||
atomic.StoreUint32(&ab.value, 0) | |||||
} | |||||
} | |||||
// TrySet sets the value of the bool and returns wether the value changed | |||||
func (ab *atomicBool) TrySet(value bool) bool { | |||||
if value { | |||||
return atomic.SwapUint32(&ab.value, 1) == 0 | |||||
} | |||||
return atomic.SwapUint32(&ab.value, 0) > 0 | |||||
} | |||||
// atomicError is a wrapper for atomically accessed error values | |||||
type atomicError struct { | |||||
_noCopy noCopy | |||||
value atomic.Value | |||||
} | |||||
// Set sets the error value regardless of the previous value. | |||||
// The value must not be nil | |||||
func (ae *atomicError) Set(value error) { | |||||
ae.value.Store(value) | |||||
} | |||||
// Value returns the current error value | |||||
func (ae *atomicError) Value() error { | |||||
if v := ae.value.Load(); v != nil { | |||||
// this will panic if the value doesn't implement the error interface | |||||
return v.(error) | |||||
} | |||||
return nil | |||||
} |
@@ -0,0 +1,40 @@ | |||||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||||
// | |||||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. | |||||
// | |||||
// This Source Code Form is subject to the terms of the Mozilla Public | |||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||||
// You can obtain one at http://mozilla.org/MPL/2.0/. | |||||
// +build go1.7 | |||||
// +build !go1.8 | |||||
package mysql | |||||
import "crypto/tls" | |||||
func cloneTLSConfig(c *tls.Config) *tls.Config { | |||||
return &tls.Config{ | |||||
Rand: c.Rand, | |||||
Time: c.Time, | |||||
Certificates: c.Certificates, | |||||
NameToCertificate: c.NameToCertificate, | |||||
GetCertificate: c.GetCertificate, | |||||
RootCAs: c.RootCAs, | |||||
NextProtos: c.NextProtos, | |||||
ServerName: c.ServerName, | |||||
ClientAuth: c.ClientAuth, | |||||
ClientCAs: c.ClientCAs, | |||||
InsecureSkipVerify: c.InsecureSkipVerify, | |||||
CipherSuites: c.CipherSuites, | |||||
PreferServerCipherSuites: c.PreferServerCipherSuites, | |||||
SessionTicketsDisabled: c.SessionTicketsDisabled, | |||||
SessionTicketKey: c.SessionTicketKey, | |||||
ClientSessionCache: c.ClientSessionCache, | |||||
MinVersion: c.MinVersion, | |||||
MaxVersion: c.MaxVersion, | |||||
CurvePreferences: c.CurvePreferences, | |||||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, | |||||
Renegotiation: c.Renegotiation, | |||||
} | |||||
} |
@@ -0,0 +1,50 @@ | |||||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||||
// | |||||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. | |||||
// | |||||
// This Source Code Form is subject to the terms of the Mozilla Public | |||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||||
// You can obtain one at http://mozilla.org/MPL/2.0/. | |||||
// +build go1.8 | |||||
package mysql | |||||
import ( | |||||
"crypto/tls" | |||||
"database/sql" | |||||
"database/sql/driver" | |||||
"errors" | |||||
"fmt" | |||||
) | |||||
func cloneTLSConfig(c *tls.Config) *tls.Config { | |||||
return c.Clone() | |||||
} | |||||
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { | |||||
dargs := make([]driver.Value, len(named)) | |||||
for n, param := range named { | |||||
if len(param.Name) > 0 { | |||||
// TODO: support the use of Named Parameters #561 | |||||
return nil, errors.New("mysql: driver does not support the use of Named Parameters") | |||||
} | |||||
dargs[n] = param.Value | |||||
} | |||||
return dargs, nil | |||||
} | |||||
func mapIsolationLevel(level driver.IsolationLevel) (string, error) { | |||||
switch sql.IsolationLevel(level) { | |||||
case sql.LevelRepeatableRead: | |||||
return "REPEATABLE READ", nil | |||||
case sql.LevelReadCommitted: | |||||
return "READ COMMITTED", nil | |||||
case sql.LevelReadUncommitted: | |||||
return "READ UNCOMMITTED", nil | |||||
case sql.LevelSerializable: | |||||
return "SERIALIZABLE", nil | |||||
default: | |||||
return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) | |||||
} | |||||
} |