@@ -294,7 +294,7 @@ | |||
[[projects]] | |||
name = "github.com/go-sql-driver/mysql" | |||
packages = ["."] | |||
revision = "ce924a41eea897745442daaa1739089b0f3f561d" | |||
revision = "d523deb1b23d913de5bdada721a6071e71283618" | |||
[[projects]] | |||
name = "github.com/go-xorm/builder" | |||
@@ -873,6 +873,6 @@ | |||
[solve-meta] | |||
analyzer-name = "dep" | |||
analyzer-version = 1 | |||
inputs-digest = "036b8c882671cf8d2c5e2fdbe53b1bdfbd39f7ebd7765bd50276c7c4ecf16687" | |||
inputs-digest = "96c83a3502bd50c5ca8e4d9b4145172267630270e587c79b7253156725eeb9b8" | |||
solver-name = "gps-cdcl" | |||
solver-version = 1 |
@@ -41,6 +41,10 @@ ignored = ["google.golang.org/appengine*"] | |||
revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03" | |||
[[override]] | |||
name = "github.com/go-sql-driver/mysql" | |||
revision = "d523deb1b23d913de5bdada721a6071e71283618" | |||
[[override]] | |||
name = "github.com/gorilla/mux" | |||
revision = "757bef944d0f21880861c2dd9c871ca543023cba" | |||
@@ -12,34 +12,63 @@ | |||
# Individual Persons | |||
Aaron Hopkins <go-sql-driver at die.net> | |||
Achille Roussel <achille.roussel at gmail.com> | |||
Alexey Palazhchenko <alexey.palazhchenko at gmail.com> | |||
Andrew Reid <andrew.reid at tixtrack.com> | |||
Arne Hormann <arnehormann at gmail.com> | |||
Asta Xie <xiemengjun at gmail.com> | |||
Bulat Gaifullin <gaifullinbf at gmail.com> | |||
Carlos Nieto <jose.carlos at menteslibres.net> | |||
Chris Moos <chris at tech9computers.com> | |||
Craig Wilson <craiggwilson at gmail.com> | |||
Daniel Montoya <dsmontoyam at gmail.com> | |||
Daniel Nichter <nil at codenode.com> | |||
Daniël van Eeden <git at myname.nl> | |||
Dave Protasowski <dprotaso at gmail.com> | |||
DisposaBoy <disposaboy at dby.me> | |||
Egor Smolyakov <egorsmkv at gmail.com> | |||
Evan Shaw <evan at vendhq.com> | |||
Frederick Mayle <frederickmayle at gmail.com> | |||
Gustavo Kristic <gkristic at gmail.com> | |||
Hajime Nakagami <nakagami at gmail.com> | |||
Hanno Braun <mail at hannobraun.com> | |||
Henri Yandell <flamefew at gmail.com> | |||
Hirotaka Yamamoto <ymmt2005 at gmail.com> | |||
ICHINOSE Shogo <shogo82148 at gmail.com> | |||
INADA Naoki <songofacandy at gmail.com> | |||
Jacek Szwec <szwec.jacek at gmail.com> | |||
James Harr <james.harr at gmail.com> | |||
Jeff Hodges <jeff at somethingsimilar.com> | |||
Jeffrey Charles <jeffreycharles at gmail.com> | |||
Jian Zhen <zhenjl at gmail.com> | |||
Joshua Prunier <joshua.prunier at gmail.com> | |||
Julien Lefevre <julien.lefevr at gmail.com> | |||
Julien Schmidt <go-sql-driver at julienschmidt.com> | |||
Justin Li <jli at j-li.net> | |||
Justin Nuß <nuss.justin at gmail.com> | |||
Kamil Dziedzic <kamil at klecza.pl> | |||
Kevin Malachowski <kevin at chowski.com> | |||
Kieron Woodhouse <kieron.woodhouse at infosum.com> | |||
Lennart Rudolph <lrudolph at hmc.edu> | |||
Leonardo YongUk Kim <dalinaum at gmail.com> | |||
Linh Tran Tuan <linhduonggnu at gmail.com> | |||
Lion Yang <lion at aosc.xyz> | |||
Luca Looz <luca.looz92 at gmail.com> | |||
Lucas Liu <extrafliu at gmail.com> | |||
Luke Scott <luke at webconnex.com> | |||
Maciej Zimnoch <maciej.zimnoch at codilime.com> | |||
Michael Woolnough <michael.woolnough at gmail.com> | |||
Nicola Peduzzi <thenikso at gmail.com> | |||
Olivier Mengué <dolmen at cpan.org> | |||
oscarzhao <oscarzhaosl at gmail.com> | |||
Paul Bonser <misterpib at gmail.com> | |||
Peter Schultz <peter.schultz at classmarkets.com> | |||
Rebecca Chin <rchin at pivotal.io> | |||
Reed Allman <rdallman10 at gmail.com> | |||
Richard Wilkes <wilkes at me.com> | |||
Robert Russell <robert at rrbrussell.com> | |||
Runrioter Wung <runrioter at gmail.com> | |||
Shuode Li <elemount at qq.com> | |||
Soroush Pour <me at soroushjp.com> | |||
Stan Putrya <root.vagner at gmail.com> | |||
Stanley Gunawan <gunawan.stanley at gmail.com> | |||
@@ -51,5 +80,10 @@ Zhenye Xie <xiezhenye at gmail.com> | |||
# Organizations | |||
Barracuda Networks, Inc. | |||
Counting Ltd. | |||
Google Inc. | |||
InfoSum Ltd. | |||
Keybase Inc. | |||
Percona LLC | |||
Pivotal Inc. | |||
Stripe Inc. |
@@ -11,7 +11,7 @@ | |||
package mysql | |||
import ( | |||
"appengine/cloudsql" | |||
"google.golang.org/appengine/cloudsql" | |||
) | |||
func init() { | |||
@@ -0,0 +1,420 @@ | |||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||
// | |||
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. | |||
// | |||
// This Source Code Form is subject to the terms of the Mozilla Public | |||
// License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||
// You can obtain one at http://mozilla.org/MPL/2.0/. | |||
package mysql | |||
import ( | |||
"crypto/rand" | |||
"crypto/rsa" | |||
"crypto/sha1" | |||
"crypto/sha256" | |||
"crypto/x509" | |||
"encoding/pem" | |||
"sync" | |||
) | |||
// server pub keys registry | |||
var ( | |||
serverPubKeyLock sync.RWMutex | |||
serverPubKeyRegistry map[string]*rsa.PublicKey | |||
) | |||
// RegisterServerPubKey registers a server RSA public key which can be used to | |||
// send data in a secure manner to the server without receiving the public key | |||
// in a potentially insecure way from the server first. | |||
// Registered keys can afterwards be used adding serverPubKey=<name> to the DSN. | |||
// | |||
// Note: The provided rsa.PublicKey instance is exclusively owned by the driver | |||
// after registering it and may not be modified. | |||
// | |||
// data, err := ioutil.ReadFile("mykey.pem") | |||
// if err != nil { | |||
// log.Fatal(err) | |||
// } | |||
// | |||
// block, _ := pem.Decode(data) | |||
// if block == nil || block.Type != "PUBLIC KEY" { | |||
// log.Fatal("failed to decode PEM block containing public key") | |||
// } | |||
// | |||
// pub, err := x509.ParsePKIXPublicKey(block.Bytes) | |||
// if err != nil { | |||
// log.Fatal(err) | |||
// } | |||
// | |||
// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { | |||
// mysql.RegisterServerPubKey("mykey", rsaPubKey) | |||
// } else { | |||
// log.Fatal("not a RSA public key") | |||
// } | |||
// | |||
func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { | |||
serverPubKeyLock.Lock() | |||
if serverPubKeyRegistry == nil { | |||
serverPubKeyRegistry = make(map[string]*rsa.PublicKey) | |||
} | |||
serverPubKeyRegistry[name] = pubKey | |||
serverPubKeyLock.Unlock() | |||
} | |||
// DeregisterServerPubKey removes the public key registered with the given name. | |||
func DeregisterServerPubKey(name string) { | |||
serverPubKeyLock.Lock() | |||
if serverPubKeyRegistry != nil { | |||
delete(serverPubKeyRegistry, name) | |||
} | |||
serverPubKeyLock.Unlock() | |||
} | |||
func getServerPubKey(name string) (pubKey *rsa.PublicKey) { | |||
serverPubKeyLock.RLock() | |||
if v, ok := serverPubKeyRegistry[name]; ok { | |||
pubKey = v | |||
} | |||
serverPubKeyLock.RUnlock() | |||
return | |||
} | |||
// Hash password using pre 4.1 (old password) method | |||
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c | |||
type myRnd struct { | |||
seed1, seed2 uint32 | |||
} | |||
const myRndMaxVal = 0x3FFFFFFF | |||
// Pseudo random number generator | |||
func newMyRnd(seed1, seed2 uint32) *myRnd { | |||
return &myRnd{ | |||
seed1: seed1 % myRndMaxVal, | |||
seed2: seed2 % myRndMaxVal, | |||
} | |||
} | |||
// Tested to be equivalent to MariaDB's floating point variant | |||
// http://play.golang.org/p/QHvhd4qved | |||
// http://play.golang.org/p/RG0q4ElWDx | |||
func (r *myRnd) NextByte() byte { | |||
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal | |||
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal | |||
return byte(uint64(r.seed1) * 31 / myRndMaxVal) | |||
} | |||
// Generate binary hash from byte string using insecure pre 4.1 method | |||
func pwHash(password []byte) (result [2]uint32) { | |||
var add uint32 = 7 | |||
var tmp uint32 | |||
result[0] = 1345345333 | |||
result[1] = 0x12345671 | |||
for _, c := range password { | |||
// skip spaces and tabs in password | |||
if c == ' ' || c == '\t' { | |||
continue | |||
} | |||
tmp = uint32(c) | |||
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) | |||
result[1] += (result[1] << 8) ^ result[0] | |||
add += tmp | |||
} | |||
// Remove sign bit (1<<31)-1) | |||
result[0] &= 0x7FFFFFFF | |||
result[1] &= 0x7FFFFFFF | |||
return | |||
} | |||
// Hash password using insecure pre 4.1 method | |||
func scrambleOldPassword(scramble []byte, password string) []byte { | |||
if len(password) == 0 { | |||
return nil | |||
} | |||
scramble = scramble[:8] | |||
hashPw := pwHash([]byte(password)) | |||
hashSc := pwHash(scramble) | |||
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) | |||
var out [8]byte | |||
for i := range out { | |||
out[i] = r.NextByte() + 64 | |||
} | |||
mask := r.NextByte() | |||
for i := range out { | |||
out[i] ^= mask | |||
} | |||
return out[:] | |||
} | |||
// Hash password using 4.1+ method (SHA1) | |||
func scramblePassword(scramble []byte, password string) []byte { | |||
if len(password) == 0 { | |||
return nil | |||
} | |||
// stage1Hash = SHA1(password) | |||
crypt := sha1.New() | |||
crypt.Write([]byte(password)) | |||
stage1 := crypt.Sum(nil) | |||
// scrambleHash = SHA1(scramble + SHA1(stage1Hash)) | |||
// inner Hash | |||
crypt.Reset() | |||
crypt.Write(stage1) | |||
hash := crypt.Sum(nil) | |||
// outer Hash | |||
crypt.Reset() | |||
crypt.Write(scramble) | |||
crypt.Write(hash) | |||
scramble = crypt.Sum(nil) | |||
// token = scrambleHash XOR stage1Hash | |||
for i := range scramble { | |||
scramble[i] ^= stage1[i] | |||
} | |||
return scramble | |||
} | |||
// Hash password using MySQL 8+ method (SHA256) | |||
func scrambleSHA256Password(scramble []byte, password string) []byte { | |||
if len(password) == 0 { | |||
return nil | |||
} | |||
// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) | |||
crypt := sha256.New() | |||
crypt.Write([]byte(password)) | |||
message1 := crypt.Sum(nil) | |||
crypt.Reset() | |||
crypt.Write(message1) | |||
message1Hash := crypt.Sum(nil) | |||
crypt.Reset() | |||
crypt.Write(message1Hash) | |||
crypt.Write(scramble) | |||
message2 := crypt.Sum(nil) | |||
for i := range message1 { | |||
message1[i] ^= message2[i] | |||
} | |||
return message1 | |||
} | |||
func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { | |||
plain := make([]byte, len(password)+1) | |||
copy(plain, password) | |||
for i := range plain { | |||
j := i % len(seed) | |||
plain[i] ^= seed[j] | |||
} | |||
sha1 := sha1.New() | |||
return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) | |||
} | |||
func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { | |||
enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) | |||
if err != nil { | |||
return err | |||
} | |||
return mc.writeAuthSwitchPacket(enc, false) | |||
} | |||
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) { | |||
switch plugin { | |||
case "caching_sha2_password": | |||
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) | |||
return authResp, (authResp == nil), nil | |||
case "mysql_old_password": | |||
if !mc.cfg.AllowOldPasswords { | |||
return nil, false, ErrOldPassword | |||
} | |||
// Note: there are edge cases where this should work but doesn't; | |||
// this is currently "wontfix": | |||
// https://github.com/go-sql-driver/mysql/issues/184 | |||
authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd) | |||
return authResp, true, nil | |||
case "mysql_clear_password": | |||
if !mc.cfg.AllowCleartextPasswords { | |||
return nil, false, ErrCleartextPassword | |||
} | |||
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html | |||
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html | |||
return []byte(mc.cfg.Passwd), true, nil | |||
case "mysql_native_password": | |||
if !mc.cfg.AllowNativePasswords { | |||
return nil, false, ErrNativePassword | |||
} | |||
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html | |||
// Native password authentication only need and will need 20-byte challenge. | |||
authResp := scramblePassword(authData[:20], mc.cfg.Passwd) | |||
return authResp, false, nil | |||
case "sha256_password": | |||
if len(mc.cfg.Passwd) == 0 { | |||
return nil, true, nil | |||
} | |||
if mc.cfg.tls != nil || mc.cfg.Net == "unix" { | |||
// write cleartext auth packet | |||
return []byte(mc.cfg.Passwd), true, nil | |||
} | |||
pubKey := mc.cfg.pubKey | |||
if pubKey == nil { | |||
// request public key from server | |||
return []byte{1}, false, nil | |||
} | |||
// encrypted password | |||
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) | |||
return enc, false, err | |||
default: | |||
errLog.Print("unknown auth plugin:", plugin) | |||
return nil, false, ErrUnknownPlugin | |||
} | |||
} | |||
func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { | |||
// Read Result Packet | |||
authData, newPlugin, err := mc.readAuthResult() | |||
if err != nil { | |||
return err | |||
} | |||
// handle auth plugin switch, if requested | |||
if newPlugin != "" { | |||
// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is | |||
// sent and we have to keep using the cipher sent in the init packet. | |||
if authData == nil { | |||
authData = oldAuthData | |||
} else { | |||
// copy data from read buffer to owned slice | |||
copy(oldAuthData, authData) | |||
} | |||
plugin = newPlugin | |||
authResp, addNUL, err := mc.auth(authData, plugin) | |||
if err != nil { | |||
return err | |||
} | |||
if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil { | |||
return err | |||
} | |||
// Read Result Packet | |||
authData, newPlugin, err = mc.readAuthResult() | |||
if err != nil { | |||
return err | |||
} | |||
// Do not allow to change the auth plugin more than once | |||
if newPlugin != "" { | |||
return ErrMalformPkt | |||
} | |||
} | |||
switch plugin { | |||
// https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ | |||
case "caching_sha2_password": | |||
switch len(authData) { | |||
case 0: | |||
return nil // auth successful | |||
case 1: | |||
switch authData[0] { | |||
case cachingSha2PasswordFastAuthSuccess: | |||
if err = mc.readResultOK(); err == nil { | |||
return nil // auth successful | |||
} | |||
case cachingSha2PasswordPerformFullAuthentication: | |||
if mc.cfg.tls != nil || mc.cfg.Net == "unix" { | |||
// write cleartext auth packet | |||
err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true) | |||
if err != nil { | |||
return err | |||
} | |||
} else { | |||
pubKey := mc.cfg.pubKey | |||
if pubKey == nil { | |||
// request public key from server | |||
data := mc.buf.takeSmallBuffer(4 + 1) | |||
data[4] = cachingSha2PasswordRequestPublicKey | |||
mc.writePacket(data) | |||
// parse public key | |||
data, err := mc.readPacket() | |||
if err != nil { | |||
return err | |||
} | |||
block, _ := pem.Decode(data[1:]) | |||
pkix, err := x509.ParsePKIXPublicKey(block.Bytes) | |||
if err != nil { | |||
return err | |||
} | |||
pubKey = pkix.(*rsa.PublicKey) | |||
} | |||
// send encrypted password | |||
err = mc.sendEncryptedPassword(oldAuthData, pubKey) | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
return mc.readResultOK() | |||
default: | |||
return ErrMalformPkt | |||
} | |||
default: | |||
return ErrMalformPkt | |||
} | |||
case "sha256_password": | |||
switch len(authData) { | |||
case 0: | |||
return nil // auth successful | |||
default: | |||
block, _ := pem.Decode(authData) | |||
pub, err := x509.ParsePKIXPublicKey(block.Bytes) | |||
if err != nil { | |||
return err | |||
} | |||
// send encrypted password | |||
err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey)) | |||
if err != nil { | |||
return err | |||
} | |||
return mc.readResultOK() | |||
} | |||
default: | |||
return nil // auth successful | |||
} | |||
return err | |||
} |
@@ -130,18 +130,18 @@ func (b *buffer) takeBuffer(length int) []byte { | |||
// smaller than defaultBufSize | |||
// Only one buffer (total) can be used at a time. | |||
func (b *buffer) takeSmallBuffer(length int) []byte { | |||
if b.length == 0 { | |||
return b.buf[:length] | |||
if b.length > 0 { | |||
return nil | |||
} | |||
return nil | |||
return b.buf[:length] | |||
} | |||
// takeCompleteBuffer returns the complete existing buffer. | |||
// This can be used if the necessary buffer size is unknown. | |||
// Only one buffer (total) can be used at a time. | |||
func (b *buffer) takeCompleteBuffer() []byte { | |||
if b.length == 0 { | |||
return b.buf | |||
if b.length > 0 { | |||
return nil | |||
} | |||
return nil | |||
return b.buf | |||
} |
@@ -9,6 +9,7 @@ | |||
package mysql | |||
const defaultCollation = "utf8_general_ci" | |||
const binaryCollation = "binary" | |||
// A list of available collations mapped to the internal ID. | |||
// To update this map use the following MySQL query: | |||
@@ -10,12 +10,23 @@ package mysql | |||
import ( | |||
"database/sql/driver" | |||
"io" | |||
"net" | |||
"strconv" | |||
"strings" | |||
"time" | |||
) | |||
// a copy of context.Context for Go 1.7 and earlier | |||
type mysqlContext interface { | |||
Done() <-chan struct{} | |||
Err() error | |||
// defined in context.Context, but not used in this driver: | |||
// Deadline() (deadline time.Time, ok bool) | |||
// Value(key interface{}) interface{} | |||
} | |||
type mysqlConn struct { | |||
buf buffer | |||
netConn net.Conn | |||
@@ -29,7 +40,14 @@ type mysqlConn struct { | |||
status statusFlag | |||
sequence uint8 | |||
parseTime bool | |||
strict bool | |||
// for context support (Go 1.8+) | |||
watching bool | |||
watcher chan<- mysqlContext | |||
closech chan struct{} | |||
finished chan<- struct{} | |||
canceled atomicError // set non-nil if conn is canceled | |||
closed atomicBool // set when conn is closed, before closech is closed | |||
} | |||
// Handles parameters set in DSN after the connection is established | |||
@@ -62,22 +80,41 @@ func (mc *mysqlConn) handleParams() (err error) { | |||
return | |||
} | |||
func (mc *mysqlConn) markBadConn(err error) error { | |||
if mc == nil { | |||
return err | |||
} | |||
if err != errBadConnNoWrite { | |||
return err | |||
} | |||
return driver.ErrBadConn | |||
} | |||
func (mc *mysqlConn) Begin() (driver.Tx, error) { | |||
if mc.netConn == nil { | |||
return mc.begin(false) | |||
} | |||
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { | |||
if mc.closed.IsSet() { | |||
errLog.Print(ErrInvalidConn) | |||
return nil, driver.ErrBadConn | |||
} | |||
err := mc.exec("START TRANSACTION") | |||
var q string | |||
if readOnly { | |||
q = "START TRANSACTION READ ONLY" | |||
} else { | |||
q = "START TRANSACTION" | |||
} | |||
err := mc.exec(q) | |||
if err == nil { | |||
return &mysqlTx{mc}, err | |||
} | |||
return nil, err | |||
return nil, mc.markBadConn(err) | |||
} | |||
func (mc *mysqlConn) Close() (err error) { | |||
// Makes Close idempotent | |||
if mc.netConn != nil { | |||
if !mc.closed.IsSet() { | |||
err = mc.writeCommandPacket(comQuit) | |||
} | |||
@@ -91,26 +128,39 @@ func (mc *mysqlConn) Close() (err error) { | |||
// is called before auth or on auth failure because MySQL will have already | |||
// closed the network connection. | |||
func (mc *mysqlConn) cleanup() { | |||
if !mc.closed.TrySet(true) { | |||
return | |||
} | |||
// Makes cleanup idempotent | |||
if mc.netConn != nil { | |||
if err := mc.netConn.Close(); err != nil { | |||
errLog.Print(err) | |||
close(mc.closech) | |||
if mc.netConn == nil { | |||
return | |||
} | |||
if err := mc.netConn.Close(); err != nil { | |||
errLog.Print(err) | |||
} | |||
} | |||
func (mc *mysqlConn) error() error { | |||
if mc.closed.IsSet() { | |||
if err := mc.canceled.Value(); err != nil { | |||
return err | |||
} | |||
mc.netConn = nil | |||
return ErrInvalidConn | |||
} | |||
mc.cfg = nil | |||
mc.buf.nc = nil | |||
return nil | |||
} | |||
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { | |||
if mc.netConn == nil { | |||
if mc.closed.IsSet() { | |||
errLog.Print(ErrInvalidConn) | |||
return nil, driver.ErrBadConn | |||
} | |||
// Send command | |||
err := mc.writeCommandPacketStr(comStmtPrepare, query) | |||
if err != nil { | |||
return nil, err | |||
return nil, mc.markBadConn(err) | |||
} | |||
stmt := &mysqlStmt{ | |||
@@ -144,7 +194,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin | |||
if buf == nil { | |||
// can not take the buffer. Something must be wrong with the connection | |||
errLog.Print(ErrBusyBuffer) | |||
return "", driver.ErrBadConn | |||
return "", ErrInvalidConn | |||
} | |||
buf = buf[:0] | |||
argPos := 0 | |||
@@ -257,7 +307,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin | |||
} | |||
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { | |||
if mc.netConn == nil { | |||
if mc.closed.IsSet() { | |||
errLog.Print(ErrInvalidConn) | |||
return nil, driver.ErrBadConn | |||
} | |||
@@ -271,7 +321,6 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err | |||
return nil, err | |||
} | |||
query = prepared | |||
args = nil | |||
} | |||
mc.affectedRows = 0 | |||
mc.insertId = 0 | |||
@@ -283,32 +332,43 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err | |||
insertId: int64(mc.insertId), | |||
}, err | |||
} | |||
return nil, err | |||
return nil, mc.markBadConn(err) | |||
} | |||
// Internal function to execute commands | |||
func (mc *mysqlConn) exec(query string) error { | |||
// Send command | |||
err := mc.writeCommandPacketStr(comQuery, query) | |||
if err != nil { | |||
return err | |||
if err := mc.writeCommandPacketStr(comQuery, query); err != nil { | |||
return mc.markBadConn(err) | |||
} | |||
// Read Result | |||
resLen, err := mc.readResultSetHeaderPacket() | |||
if err == nil && resLen > 0 { | |||
if err = mc.readUntilEOF(); err != nil { | |||
if err != nil { | |||
return err | |||
} | |||
if resLen > 0 { | |||
// columns | |||
if err := mc.readUntilEOF(); err != nil { | |||
return err | |||
} | |||
err = mc.readUntilEOF() | |||
// rows | |||
if err := mc.readUntilEOF(); err != nil { | |||
return err | |||
} | |||
} | |||
return err | |||
return mc.discardResults() | |||
} | |||
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { | |||
if mc.netConn == nil { | |||
return mc.query(query, args) | |||
} | |||
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { | |||
if mc.closed.IsSet() { | |||
errLog.Print(ErrInvalidConn) | |||
return nil, driver.ErrBadConn | |||
} | |||
@@ -322,7 +382,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro | |||
return nil, err | |||
} | |||
query = prepared | |||
args = nil | |||
} | |||
// Send command | |||
err := mc.writeCommandPacketStr(comQuery, query) | |||
@@ -335,15 +394,22 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro | |||
rows.mc = mc | |||
if resLen == 0 { | |||
// no columns, no more data | |||
return emptyRows{}, nil | |||
rows.rs.done = true | |||
switch err := rows.NextResultSet(); err { | |||
case nil, io.EOF: | |||
return rows, nil | |||
default: | |||
return nil, err | |||
} | |||
} | |||
// Columns | |||
rows.columns, err = mc.readColumns(resLen) | |||
rows.rs.columns, err = mc.readColumns(resLen) | |||
return rows, err | |||
} | |||
} | |||
return nil, err | |||
return nil, mc.markBadConn(err) | |||
} | |||
// Gets the value of the given MySQL System Variable | |||
@@ -359,7 +425,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { | |||
if err == nil { | |||
rows := new(textRows) | |||
rows.mc = mc | |||
rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}} | |||
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} | |||
if resLen > 0 { | |||
// Columns | |||
@@ -375,3 +441,21 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { | |||
} | |||
return nil, err | |||
} | |||
// finish is called when the query has canceled. | |||
func (mc *mysqlConn) cancel(err error) { | |||
mc.canceled.Set(err) | |||
mc.cleanup() | |||
} | |||
// finish is called when the query has succeeded. | |||
func (mc *mysqlConn) finish() { | |||
if !mc.watching || mc.finished == nil { | |||
return | |||
} | |||
select { | |||
case mc.finished <- struct{}{}: | |||
mc.watching = false | |||
case <-mc.closech: | |||
} | |||
} |
@@ -0,0 +1,208 @@ | |||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||
// | |||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. | |||
// | |||
// This Source Code Form is subject to the terms of the Mozilla Public | |||
// License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||
// You can obtain one at http://mozilla.org/MPL/2.0/. | |||
// +build go1.8 | |||
package mysql | |||
import ( | |||
"context" | |||
"database/sql" | |||
"database/sql/driver" | |||
) | |||
// Ping implements driver.Pinger interface | |||
func (mc *mysqlConn) Ping(ctx context.Context) (err error) { | |||
if mc.closed.IsSet() { | |||
errLog.Print(ErrInvalidConn) | |||
return driver.ErrBadConn | |||
} | |||
if err = mc.watchCancel(ctx); err != nil { | |||
return | |||
} | |||
defer mc.finish() | |||
if err = mc.writeCommandPacket(comPing); err != nil { | |||
return | |||
} | |||
return mc.readResultOK() | |||
} | |||
// BeginTx implements driver.ConnBeginTx interface | |||
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | |||
if err := mc.watchCancel(ctx); err != nil { | |||
return nil, err | |||
} | |||
defer mc.finish() | |||
if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { | |||
level, err := mapIsolationLevel(opts.Isolation) | |||
if err != nil { | |||
return nil, err | |||
} | |||
err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) | |||
if err != nil { | |||
return nil, err | |||
} | |||
} | |||
return mc.begin(opts.ReadOnly) | |||
} | |||
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||
dargs, err := namedValueToValue(args) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if err := mc.watchCancel(ctx); err != nil { | |||
return nil, err | |||
} | |||
rows, err := mc.query(query, dargs) | |||
if err != nil { | |||
mc.finish() | |||
return nil, err | |||
} | |||
rows.finish = mc.finish | |||
return rows, err | |||
} | |||
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | |||
dargs, err := namedValueToValue(args) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if err := mc.watchCancel(ctx); err != nil { | |||
return nil, err | |||
} | |||
defer mc.finish() | |||
return mc.Exec(query, dargs) | |||
} | |||
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | |||
if err := mc.watchCancel(ctx); err != nil { | |||
return nil, err | |||
} | |||
stmt, err := mc.Prepare(query) | |||
mc.finish() | |||
if err != nil { | |||
return nil, err | |||
} | |||
select { | |||
default: | |||
case <-ctx.Done(): | |||
stmt.Close() | |||
return nil, ctx.Err() | |||
} | |||
return stmt, nil | |||
} | |||
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { | |||
dargs, err := namedValueToValue(args) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if err := stmt.mc.watchCancel(ctx); err != nil { | |||
return nil, err | |||
} | |||
rows, err := stmt.query(dargs) | |||
if err != nil { | |||
stmt.mc.finish() | |||
return nil, err | |||
} | |||
rows.finish = stmt.mc.finish | |||
return rows, err | |||
} | |||
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { | |||
dargs, err := namedValueToValue(args) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if err := stmt.mc.watchCancel(ctx); err != nil { | |||
return nil, err | |||
} | |||
defer stmt.mc.finish() | |||
return stmt.Exec(dargs) | |||
} | |||
func (mc *mysqlConn) watchCancel(ctx context.Context) error { | |||
if mc.watching { | |||
// Reach here if canceled, | |||
// so the connection is already invalid | |||
mc.cleanup() | |||
return nil | |||
} | |||
if ctx.Done() == nil { | |||
return nil | |||
} | |||
mc.watching = true | |||
select { | |||
default: | |||
case <-ctx.Done(): | |||
return ctx.Err() | |||
} | |||
if mc.watcher == nil { | |||
return nil | |||
} | |||
mc.watcher <- ctx | |||
return nil | |||
} | |||
func (mc *mysqlConn) startWatcher() { | |||
watcher := make(chan mysqlContext, 1) | |||
mc.watcher = watcher | |||
finished := make(chan struct{}) | |||
mc.finished = finished | |||
go func() { | |||
for { | |||
var ctx mysqlContext | |||
select { | |||
case ctx = <-watcher: | |||
case <-mc.closech: | |||
return | |||
} | |||
select { | |||
case <-ctx.Done(): | |||
mc.cancel(ctx.Err()) | |||
case <-finished: | |||
case <-mc.closech: | |||
return | |||
} | |||
} | |||
}() | |||
} | |||
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { | |||
nv.Value, err = converter{}.ConvertValue(nv.Value) | |||
return | |||
} | |||
// ResetSession implements driver.SessionResetter. | |||
// (From Go 1.10) | |||
func (mc *mysqlConn) ResetSession(ctx context.Context) error { | |||
if mc.closed.IsSet() { | |||
return driver.ErrBadConn | |||
} | |||
return nil | |||
} |
@@ -9,7 +9,9 @@ | |||
package mysql | |||
const ( | |||
minProtocolVersion byte = 10 | |||
defaultAuthPlugin = "mysql_native_password" | |||
defaultMaxAllowedPacket = 4 << 20 // 4 MiB | |||
minProtocolVersion = 10 | |||
maxPacketSize = 1<<24 - 1 | |||
timeFormat = "2006-01-02 15:04:05.999999" | |||
) | |||
@@ -18,10 +20,11 @@ const ( | |||
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html | |||
const ( | |||
iOK byte = 0x00 | |||
iLocalInFile byte = 0xfb | |||
iEOF byte = 0xfe | |||
iERR byte = 0xff | |||
iOK byte = 0x00 | |||
iAuthMoreData byte = 0x01 | |||
iLocalInFile byte = 0xfb | |||
iEOF byte = 0xfe | |||
iERR byte = 0xff | |||
) | |||
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags | |||
@@ -87,8 +90,10 @@ const ( | |||
) | |||
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType | |||
type fieldType byte | |||
const ( | |||
fieldTypeDecimal byte = iota | |||
fieldTypeDecimal fieldType = iota | |||
fieldTypeTiny | |||
fieldTypeShort | |||
fieldTypeLong | |||
@@ -107,7 +112,7 @@ const ( | |||
fieldTypeBit | |||
) | |||
const ( | |||
fieldTypeJSON byte = iota + 0xf5 | |||
fieldTypeJSON fieldType = iota + 0xf5 | |||
fieldTypeNewDecimal | |||
fieldTypeEnum | |||
fieldTypeSet | |||
@@ -161,3 +166,9 @@ const ( | |||
statusInTransReadonly | |||
statusSessionStateChanged | |||
) | |||
const ( | |||
cachingSha2PasswordRequestPublicKey = 2 | |||
cachingSha2PasswordFastAuthSuccess = 3 | |||
cachingSha2PasswordPerformFullAuthentication = 4 | |||
) |
@@ -4,7 +4,7 @@ | |||
// License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||
// You can obtain one at http://mozilla.org/MPL/2.0/. | |||
// Package mysql provides a MySQL driver for Go's database/sql package | |||
// Package mysql provides a MySQL driver for Go's database/sql package. | |||
// | |||
// The driver should be used via the database/sql package: | |||
// | |||
@@ -20,8 +20,14 @@ import ( | |||
"database/sql" | |||
"database/sql/driver" | |||
"net" | |||
"sync" | |||
) | |||
// watcher interface is used for context support (From Go 1.8) | |||
type watcher interface { | |||
startWatcher() | |||
} | |||
// MySQLDriver is exported to make the driver directly accessible. | |||
// In general the driver is used via the database/sql package. | |||
type MySQLDriver struct{} | |||
@@ -30,12 +36,17 @@ type MySQLDriver struct{} | |||
// Custom dial functions must be registered with RegisterDial | |||
type DialFunc func(addr string) (net.Conn, error) | |||
var dials map[string]DialFunc | |||
var ( | |||
dialsLock sync.RWMutex | |||
dials map[string]DialFunc | |||
) | |||
// RegisterDial registers a custom dial function. It can then be used by the | |||
// network address mynet(addr), where mynet is the registered new network. | |||
// addr is passed as a parameter to the dial function. | |||
func RegisterDial(net string, dial DialFunc) { | |||
dialsLock.Lock() | |||
defer dialsLock.Unlock() | |||
if dials == nil { | |||
dials = make(map[string]DialFunc) | |||
} | |||
@@ -52,16 +63,19 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |||
mc := &mysqlConn{ | |||
maxAllowedPacket: maxPacketSize, | |||
maxWriteSize: maxPacketSize - 1, | |||
closech: make(chan struct{}), | |||
} | |||
mc.cfg, err = ParseDSN(dsn) | |||
if err != nil { | |||
return nil, err | |||
} | |||
mc.parseTime = mc.cfg.ParseTime | |||
mc.strict = mc.cfg.Strict | |||
// Connect to Server | |||
if dial, ok := dials[mc.cfg.Net]; ok { | |||
dialsLock.RLock() | |||
dial, ok := dials[mc.cfg.Net] | |||
dialsLock.RUnlock() | |||
if ok { | |||
mc.netConn, err = dial(mc.cfg.Addr) | |||
} else { | |||
nd := net.Dialer{Timeout: mc.cfg.Timeout} | |||
@@ -81,6 +95,11 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |||
} | |||
} | |||
// Call startWatcher for context support (From Go 1.8) | |||
if s, ok := interface{}(mc).(watcher); ok { | |||
s.startWatcher() | |||
} | |||
mc.buf = newBuffer(mc.netConn) | |||
// Set I/O timeouts | |||
@@ -88,20 +107,31 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |||
mc.writeTimeout = mc.cfg.WriteTimeout | |||
// Reading Handshake Initialization Packet | |||
cipher, err := mc.readInitPacket() | |||
authData, plugin, err := mc.readHandshakePacket() | |||
if err != nil { | |||
mc.cleanup() | |||
return nil, err | |||
} | |||
// Send Client Authentication Packet | |||
if err = mc.writeAuthPacket(cipher); err != nil { | |||
authResp, addNUL, err := mc.auth(authData, plugin) | |||
if err != nil { | |||
// try the default auth plugin, if using the requested plugin failed | |||
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) | |||
plugin = defaultAuthPlugin | |||
authResp, addNUL, err = mc.auth(authData, plugin) | |||
if err != nil { | |||
mc.cleanup() | |||
return nil, err | |||
} | |||
} | |||
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { | |||
mc.cleanup() | |||
return nil, err | |||
} | |||
// Handle response to auth packet, switch methods if possible | |||
if err = handleAuthResult(mc); err != nil { | |||
if err = mc.handleAuthResult(authData, plugin); err != nil { | |||
// Authentication failed and MySQL has already closed the connection | |||
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html). | |||
// Do not send COM_QUIT, just cleanup and return the error. | |||
@@ -134,43 +164,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |||
return mc, nil | |||
} | |||
func handleAuthResult(mc *mysqlConn) error { | |||
// Read Result Packet | |||
cipher, err := mc.readResultOK() | |||
if err == nil { | |||
return nil // auth successful | |||
} | |||
if mc.cfg == nil { | |||
return err // auth failed and retry not possible | |||
} | |||
// Retry auth if configured to do so. | |||
if mc.cfg.AllowOldPasswords && err == ErrOldPassword { | |||
// Retry with old authentication method. Note: there are edge cases | |||
// where this should work but doesn't; this is currently "wontfix": | |||
// https://github.com/go-sql-driver/mysql/issues/184 | |||
if err = mc.writeOldAuthPacket(cipher); err != nil { | |||
return err | |||
} | |||
_, err = mc.readResultOK() | |||
} else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword { | |||
// Retry with clear text password for | |||
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html | |||
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html | |||
if err = mc.writeClearAuthPacket(); err != nil { | |||
return err | |||
} | |||
_, err = mc.readResultOK() | |||
} else if mc.cfg.AllowNativePasswords && err == ErrNativePassword { | |||
if err = mc.writeNativeAuthPacket(cipher); err != nil { | |||
return err | |||
} | |||
_, err = mc.readResultOK() | |||
} | |||
return err | |||
} | |||
func init() { | |||
sql.Register("mysql", &MySQLDriver{}) | |||
} |
@@ -10,11 +10,13 @@ package mysql | |||
import ( | |||
"bytes" | |||
"crypto/rsa" | |||
"crypto/tls" | |||
"errors" | |||
"fmt" | |||
"net" | |||
"net/url" | |||
"sort" | |||
"strconv" | |||
"strings" | |||
"time" | |||
@@ -27,7 +29,9 @@ var ( | |||
errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") | |||
) | |||
// Config is a configuration parsed from a DSN string | |||
// Config is a configuration parsed from a DSN string. | |||
// If a new Config is created instead of being parsed from a DSN string, | |||
// the NewConfig function should be used, which sets default values. | |||
type Config struct { | |||
User string // Username | |||
Passwd string // Password (requires User) | |||
@@ -38,6 +42,8 @@ type Config struct { | |||
Collation string // Connection collation | |||
Loc *time.Location // Location for time.Time values | |||
MaxAllowedPacket int // Max packet size allowed | |||
ServerPubKey string // Server public key name | |||
pubKey *rsa.PublicKey // Server public key | |||
TLSConfig string // TLS configuration name | |||
tls *tls.Config // TLS configuration | |||
Timeout time.Duration // Dial timeout | |||
@@ -53,7 +59,54 @@ type Config struct { | |||
InterpolateParams bool // Interpolate placeholders into query string | |||
MultiStatements bool // Allow multiple statements in one query | |||
ParseTime bool // Parse time values to time.Time | |||
Strict bool // Return warnings as errors | |||
RejectReadOnly bool // Reject read-only connections | |||
} | |||
// NewConfig creates a new Config and sets default values. | |||
func NewConfig() *Config { | |||
return &Config{ | |||
Collation: defaultCollation, | |||
Loc: time.UTC, | |||
MaxAllowedPacket: defaultMaxAllowedPacket, | |||
AllowNativePasswords: true, | |||
} | |||
} | |||
func (cfg *Config) normalize() error { | |||
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { | |||
return errInvalidDSNUnsafeCollation | |||
} | |||
// Set default network if empty | |||
if cfg.Net == "" { | |||
cfg.Net = "tcp" | |||
} | |||
// Set default address if empty | |||
if cfg.Addr == "" { | |||
switch cfg.Net { | |||
case "tcp": | |||
cfg.Addr = "127.0.0.1:3306" | |||
case "unix": | |||
cfg.Addr = "/tmp/mysql.sock" | |||
default: | |||
return errors.New("default addr for network '" + cfg.Net + "' unknown") | |||
} | |||
} else if cfg.Net == "tcp" { | |||
cfg.Addr = ensureHavePort(cfg.Addr) | |||
} | |||
if cfg.tls != nil { | |||
if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { | |||
host, _, err := net.SplitHostPort(cfg.Addr) | |||
if err == nil { | |||
cfg.tls.ServerName = host | |||
} | |||
} | |||
} | |||
return nil | |||
} | |||
// FormatDSN formats the given Config into a DSN string which can be passed to | |||
@@ -102,12 +155,12 @@ func (cfg *Config) FormatDSN() string { | |||
} | |||
} | |||
if cfg.AllowNativePasswords { | |||
if !cfg.AllowNativePasswords { | |||
if hasParam { | |||
buf.WriteString("&allowNativePasswords=true") | |||
buf.WriteString("&allowNativePasswords=false") | |||
} else { | |||
hasParam = true | |||
buf.WriteString("?allowNativePasswords=true") | |||
buf.WriteString("?allowNativePasswords=false") | |||
} | |||
} | |||
@@ -195,15 +248,25 @@ func (cfg *Config) FormatDSN() string { | |||
buf.WriteString(cfg.ReadTimeout.String()) | |||
} | |||
if cfg.Strict { | |||
if cfg.RejectReadOnly { | |||
if hasParam { | |||
buf.WriteString("&strict=true") | |||
buf.WriteString("&rejectReadOnly=true") | |||
} else { | |||
hasParam = true | |||
buf.WriteString("?strict=true") | |||
buf.WriteString("?rejectReadOnly=true") | |||
} | |||
} | |||
if len(cfg.ServerPubKey) > 0 { | |||
if hasParam { | |||
buf.WriteString("&serverPubKey=") | |||
} else { | |||
hasParam = true | |||
buf.WriteString("?serverPubKey=") | |||
} | |||
buf.WriteString(url.QueryEscape(cfg.ServerPubKey)) | |||
} | |||
if cfg.Timeout > 0 { | |||
if hasParam { | |||
buf.WriteString("&timeout=") | |||
@@ -234,7 +297,7 @@ func (cfg *Config) FormatDSN() string { | |||
buf.WriteString(cfg.WriteTimeout.String()) | |||
} | |||
if cfg.MaxAllowedPacket > 0 { | |||
if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { | |||
if hasParam { | |||
buf.WriteString("&maxAllowedPacket=") | |||
} else { | |||
@@ -247,7 +310,12 @@ func (cfg *Config) FormatDSN() string { | |||
// other params | |||
if cfg.Params != nil { | |||
for param, value := range cfg.Params { | |||
var params []string | |||
for param := range cfg.Params { | |||
params = append(params, param) | |||
} | |||
sort.Strings(params) | |||
for _, param := range params { | |||
if hasParam { | |||
buf.WriteByte('&') | |||
} else { | |||
@@ -257,7 +325,7 @@ func (cfg *Config) FormatDSN() string { | |||
buf.WriteString(param) | |||
buf.WriteByte('=') | |||
buf.WriteString(url.QueryEscape(value)) | |||
buf.WriteString(url.QueryEscape(cfg.Params[param])) | |||
} | |||
} | |||
@@ -267,10 +335,7 @@ func (cfg *Config) FormatDSN() string { | |||
// ParseDSN parses the DSN string to a Config | |||
func ParseDSN(dsn string) (cfg *Config, err error) { | |||
// New config with some default values | |||
cfg = &Config{ | |||
Loc: time.UTC, | |||
Collation: defaultCollation, | |||
} | |||
cfg = NewConfig() | |||
// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] | |||
// Find the last '/' (since the password or the net addr might contain a '/') | |||
@@ -338,28 +403,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) { | |||
return nil, errInvalidDSNNoSlash | |||
} | |||
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { | |||
return nil, errInvalidDSNUnsafeCollation | |||
} | |||
// Set default network if empty | |||
if cfg.Net == "" { | |||
cfg.Net = "tcp" | |||
if err = cfg.normalize(); err != nil { | |||
return nil, err | |||
} | |||
// Set default address if empty | |||
if cfg.Addr == "" { | |||
switch cfg.Net { | |||
case "tcp": | |||
cfg.Addr = "127.0.0.1:3306" | |||
case "unix": | |||
cfg.Addr = "/tmp/mysql.sock" | |||
default: | |||
return nil, errors.New("default addr for network '" + cfg.Net + "' unknown") | |||
} | |||
} | |||
return | |||
} | |||
@@ -374,7 +420,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { | |||
// cfg params | |||
switch value := param[1]; param[0] { | |||
// Disable INFILE whitelist / enable all files | |||
case "allowAllFiles": | |||
var isBool bool | |||
@@ -472,14 +517,32 @@ func parseDSNParams(cfg *Config, params string) (err error) { | |||
return | |||
} | |||
// Strict mode | |||
case "strict": | |||
// Reject read-only connections | |||
case "rejectReadOnly": | |||
var isBool bool | |||
cfg.Strict, isBool = readBool(value) | |||
cfg.RejectReadOnly, isBool = readBool(value) | |||
if !isBool { | |||
return errors.New("invalid bool value: " + value) | |||
} | |||
// Server public key | |||
case "serverPubKey": | |||
name, err := url.QueryUnescape(value) | |||
if err != nil { | |||
return fmt.Errorf("invalid value for server pub key name: %v", err) | |||
} | |||
if pubKey := getServerPubKey(name); pubKey != nil { | |||
cfg.ServerPubKey = name | |||
cfg.pubKey = pubKey | |||
} else { | |||
return errors.New("invalid value / unknown server pub key name: " + name) | |||
} | |||
// Strict mode | |||
case "strict": | |||
panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") | |||
// Dial Timeout | |||
case "timeout": | |||
cfg.Timeout, err = time.ParseDuration(value) | |||
@@ -506,14 +569,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { | |||
return fmt.Errorf("invalid value for TLS config name: %v", err) | |||
} | |||
if tlsConfig, ok := tlsConfigRegister[name]; ok { | |||
if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { | |||
host, _, err := net.SplitHostPort(cfg.Addr) | |||
if err == nil { | |||
tlsConfig.ServerName = host | |||
} | |||
} | |||
if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { | |||
cfg.TLSConfig = name | |||
cfg.tls = tlsConfig | |||
} else { | |||
@@ -546,3 +602,10 @@ func parseDSNParams(cfg *Config, params string) (err error) { | |||
return | |||
} | |||
func ensureHavePort(addr string) string { | |||
if _, _, err := net.SplitHostPort(addr); err != nil { | |||
return net.JoinHostPort(addr, "3306") | |||
} | |||
return addr | |||
} |
@@ -9,10 +9,8 @@ | |||
package mysql | |||
import ( | |||
"database/sql/driver" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"log" | |||
"os" | |||
) | |||
@@ -31,6 +29,12 @@ var ( | |||
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") | |||
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") | |||
ErrBusyBuffer = errors.New("busy buffer") | |||
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. | |||
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn | |||
// to trigger a resend. | |||
// See https://github.com/go-sql-driver/mysql/pull/302 | |||
errBadConnNoWrite = errors.New("bad connection") | |||
) | |||
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) | |||
@@ -59,74 +63,3 @@ type MySQLError struct { | |||
func (me *MySQLError) Error() string { | |||
return fmt.Sprintf("Error %d: %s", me.Number, me.Message) | |||
} | |||
// MySQLWarnings is an error type which represents a group of one or more MySQL | |||
// warnings | |||
type MySQLWarnings []MySQLWarning | |||
func (mws MySQLWarnings) Error() string { | |||
var msg string | |||
for i, warning := range mws { | |||
if i > 0 { | |||
msg += "\r\n" | |||
} | |||
msg += fmt.Sprintf( | |||
"%s %s: %s", | |||
warning.Level, | |||
warning.Code, | |||
warning.Message, | |||
) | |||
} | |||
return msg | |||
} | |||
// MySQLWarning is an error type which represents a single MySQL warning. | |||
// Warnings are returned in groups only. See MySQLWarnings | |||
type MySQLWarning struct { | |||
Level string | |||
Code string | |||
Message string | |||
} | |||
func (mc *mysqlConn) getWarnings() (err error) { | |||
rows, err := mc.Query("SHOW WARNINGS", nil) | |||
if err != nil { | |||
return | |||
} | |||
var warnings = MySQLWarnings{} | |||
var values = make([]driver.Value, 3) | |||
for { | |||
err = rows.Next(values) | |||
switch err { | |||
case nil: | |||
warning := MySQLWarning{} | |||
if raw, ok := values[0].([]byte); ok { | |||
warning.Level = string(raw) | |||
} else { | |||
warning.Level = fmt.Sprintf("%s", values[0]) | |||
} | |||
if raw, ok := values[1].([]byte); ok { | |||
warning.Code = string(raw) | |||
} else { | |||
warning.Code = fmt.Sprintf("%s", values[1]) | |||
} | |||
if raw, ok := values[2].([]byte); ok { | |||
warning.Message = string(raw) | |||
} else { | |||
warning.Message = fmt.Sprintf("%s", values[0]) | |||
} | |||
warnings = append(warnings, warning) | |||
case io.EOF: | |||
return warnings | |||
default: | |||
rows.Close() | |||
return | |||
} | |||
} | |||
} |
@@ -0,0 +1,194 @@ | |||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||
// | |||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. | |||
// | |||
// This Source Code Form is subject to the terms of the Mozilla Public | |||
// License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||
// You can obtain one at http://mozilla.org/MPL/2.0/. | |||
package mysql | |||
import ( | |||
"database/sql" | |||
"reflect" | |||
) | |||
func (mf *mysqlField) typeDatabaseName() string { | |||
switch mf.fieldType { | |||
case fieldTypeBit: | |||
return "BIT" | |||
case fieldTypeBLOB: | |||
if mf.charSet != collations[binaryCollation] { | |||
return "TEXT" | |||
} | |||
return "BLOB" | |||
case fieldTypeDate: | |||
return "DATE" | |||
case fieldTypeDateTime: | |||
return "DATETIME" | |||
case fieldTypeDecimal: | |||
return "DECIMAL" | |||
case fieldTypeDouble: | |||
return "DOUBLE" | |||
case fieldTypeEnum: | |||
return "ENUM" | |||
case fieldTypeFloat: | |||
return "FLOAT" | |||
case fieldTypeGeometry: | |||
return "GEOMETRY" | |||
case fieldTypeInt24: | |||
return "MEDIUMINT" | |||
case fieldTypeJSON: | |||
return "JSON" | |||
case fieldTypeLong: | |||
return "INT" | |||
case fieldTypeLongBLOB: | |||
if mf.charSet != collations[binaryCollation] { | |||
return "LONGTEXT" | |||
} | |||
return "LONGBLOB" | |||
case fieldTypeLongLong: | |||
return "BIGINT" | |||
case fieldTypeMediumBLOB: | |||
if mf.charSet != collations[binaryCollation] { | |||
return "MEDIUMTEXT" | |||
} | |||
return "MEDIUMBLOB" | |||
case fieldTypeNewDate: | |||
return "DATE" | |||
case fieldTypeNewDecimal: | |||
return "DECIMAL" | |||
case fieldTypeNULL: | |||
return "NULL" | |||
case fieldTypeSet: | |||
return "SET" | |||
case fieldTypeShort: | |||
return "SMALLINT" | |||
case fieldTypeString: | |||
if mf.charSet == collations[binaryCollation] { | |||
return "BINARY" | |||
} | |||
return "CHAR" | |||
case fieldTypeTime: | |||
return "TIME" | |||
case fieldTypeTimestamp: | |||
return "TIMESTAMP" | |||
case fieldTypeTiny: | |||
return "TINYINT" | |||
case fieldTypeTinyBLOB: | |||
if mf.charSet != collations[binaryCollation] { | |||
return "TINYTEXT" | |||
} | |||
return "TINYBLOB" | |||
case fieldTypeVarChar: | |||
if mf.charSet == collations[binaryCollation] { | |||
return "VARBINARY" | |||
} | |||
return "VARCHAR" | |||
case fieldTypeVarString: | |||
if mf.charSet == collations[binaryCollation] { | |||
return "VARBINARY" | |||
} | |||
return "VARCHAR" | |||
case fieldTypeYear: | |||
return "YEAR" | |||
default: | |||
return "" | |||
} | |||
} | |||
var ( | |||
scanTypeFloat32 = reflect.TypeOf(float32(0)) | |||
scanTypeFloat64 = reflect.TypeOf(float64(0)) | |||
scanTypeInt8 = reflect.TypeOf(int8(0)) | |||
scanTypeInt16 = reflect.TypeOf(int16(0)) | |||
scanTypeInt32 = reflect.TypeOf(int32(0)) | |||
scanTypeInt64 = reflect.TypeOf(int64(0)) | |||
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) | |||
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) | |||
scanTypeNullTime = reflect.TypeOf(NullTime{}) | |||
scanTypeUint8 = reflect.TypeOf(uint8(0)) | |||
scanTypeUint16 = reflect.TypeOf(uint16(0)) | |||
scanTypeUint32 = reflect.TypeOf(uint32(0)) | |||
scanTypeUint64 = reflect.TypeOf(uint64(0)) | |||
scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) | |||
scanTypeUnknown = reflect.TypeOf(new(interface{})) | |||
) | |||
type mysqlField struct { | |||
tableName string | |||
name string | |||
length uint32 | |||
flags fieldFlag | |||
fieldType fieldType | |||
decimals byte | |||
charSet uint8 | |||
} | |||
func (mf *mysqlField) scanType() reflect.Type { | |||
switch mf.fieldType { | |||
case fieldTypeTiny: | |||
if mf.flags&flagNotNULL != 0 { | |||
if mf.flags&flagUnsigned != 0 { | |||
return scanTypeUint8 | |||
} | |||
return scanTypeInt8 | |||
} | |||
return scanTypeNullInt | |||
case fieldTypeShort, fieldTypeYear: | |||
if mf.flags&flagNotNULL != 0 { | |||
if mf.flags&flagUnsigned != 0 { | |||
return scanTypeUint16 | |||
} | |||
return scanTypeInt16 | |||
} | |||
return scanTypeNullInt | |||
case fieldTypeInt24, fieldTypeLong: | |||
if mf.flags&flagNotNULL != 0 { | |||
if mf.flags&flagUnsigned != 0 { | |||
return scanTypeUint32 | |||
} | |||
return scanTypeInt32 | |||
} | |||
return scanTypeNullInt | |||
case fieldTypeLongLong: | |||
if mf.flags&flagNotNULL != 0 { | |||
if mf.flags&flagUnsigned != 0 { | |||
return scanTypeUint64 | |||
} | |||
return scanTypeInt64 | |||
} | |||
return scanTypeNullInt | |||
case fieldTypeFloat: | |||
if mf.flags&flagNotNULL != 0 { | |||
return scanTypeFloat32 | |||
} | |||
return scanTypeNullFloat | |||
case fieldTypeDouble: | |||
if mf.flags&flagNotNULL != 0 { | |||
return scanTypeFloat64 | |||
} | |||
return scanTypeNullFloat | |||
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, | |||
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, | |||
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, | |||
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, | |||
fieldTypeTime: | |||
return scanTypeRawBytes | |||
case fieldTypeDate, fieldTypeNewDate, | |||
fieldTypeTimestamp, fieldTypeDateTime: | |||
// NullTime is always returned for more consistent behavior as it can | |||
// handle both cases of parseTime regardless if the field is nullable. | |||
return scanTypeNullTime | |||
default: | |||
return scanTypeUnknown | |||
} | |||
} |
@@ -147,7 +147,8 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { | |||
} | |||
// send content packets | |||
if err == nil { | |||
// if packetSize == 0, the Reader contains no data | |||
if err == nil && packetSize > 0 { | |||
data := make([]byte, 4+packetSize) | |||
var n int | |||
for err == nil { | |||
@@ -173,8 +174,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { | |||
// read OK packet | |||
if err == nil { | |||
_, err = mc.readResultOK() | |||
return err | |||
return mc.readResultOK() | |||
} | |||
mc.readPacket() | |||
@@ -25,26 +25,23 @@ import ( | |||
// Read packet to buffer 'data' | |||
func (mc *mysqlConn) readPacket() ([]byte, error) { | |||
var payload []byte | |||
var prevData []byte | |||
for { | |||
// Read packet header | |||
// read packet header | |||
data, err := mc.buf.readNext(4) | |||
if err != nil { | |||
if cerr := mc.canceled.Value(); cerr != nil { | |||
return nil, cerr | |||
} | |||
errLog.Print(err) | |||
mc.Close() | |||
return nil, driver.ErrBadConn | |||
return nil, ErrInvalidConn | |||
} | |||
// Packet Length [24 bit] | |||
// packet length [24 bit] | |||
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) | |||
if pktLen < 1 { | |||
errLog.Print(ErrMalformPkt) | |||
mc.Close() | |||
return nil, driver.ErrBadConn | |||
} | |||
// Check Packet Sync [8 bit] | |||
// check packet sync [8 bit] | |||
if data[3] != mc.sequence { | |||
if data[3] > mc.sequence { | |||
return nil, ErrPktSyncMul | |||
@@ -53,26 +50,41 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { | |||
} | |||
mc.sequence++ | |||
// Read packet body [pktLen bytes] | |||
// packets with length 0 terminate a previous packet which is a | |||
// multiple of (2^24)−1 bytes long | |||
if pktLen == 0 { | |||
// there was no previous packet | |||
if prevData == nil { | |||
errLog.Print(ErrMalformPkt) | |||
mc.Close() | |||
return nil, ErrInvalidConn | |||
} | |||
return prevData, nil | |||
} | |||
// read packet body [pktLen bytes] | |||
data, err = mc.buf.readNext(pktLen) | |||
if err != nil { | |||
if cerr := mc.canceled.Value(); cerr != nil { | |||
return nil, cerr | |||
} | |||
errLog.Print(err) | |||
mc.Close() | |||
return nil, driver.ErrBadConn | |||
return nil, ErrInvalidConn | |||
} | |||
isLastPacket := (pktLen < maxPacketSize) | |||
// return data if this was the last packet | |||
if pktLen < maxPacketSize { | |||
// zero allocations for non-split packets | |||
if prevData == nil { | |||
return data, nil | |||
} | |||
// Zero allocations for non-splitting packets | |||
if isLastPacket && payload == nil { | |||
return data, nil | |||
return append(prevData, data...), nil | |||
} | |||
payload = append(payload, data...) | |||
if isLastPacket { | |||
return payload, nil | |||
} | |||
prevData = append(prevData, data...) | |||
} | |||
} | |||
@@ -119,33 +131,47 @@ func (mc *mysqlConn) writePacket(data []byte) error { | |||
// Handle error | |||
if err == nil { // n != len(data) | |||
mc.cleanup() | |||
errLog.Print(ErrMalformPkt) | |||
} else { | |||
if cerr := mc.canceled.Value(); cerr != nil { | |||
return cerr | |||
} | |||
if n == 0 && pktLen == len(data)-4 { | |||
// only for the first loop iteration when nothing was written yet | |||
return errBadConnNoWrite | |||
} | |||
mc.cleanup() | |||
errLog.Print(err) | |||
} | |||
return driver.ErrBadConn | |||
return ErrInvalidConn | |||
} | |||
} | |||
/****************************************************************************** | |||
* Initialisation Process * | |||
* Initialization Process * | |||
******************************************************************************/ | |||
// Handshake Initialization Packet | |||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake | |||
func (mc *mysqlConn) readInitPacket() ([]byte, error) { | |||
func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { | |||
data, err := mc.readPacket() | |||
if err != nil { | |||
return nil, err | |||
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since | |||
// in connection initialization we don't risk retrying non-idempotent actions. | |||
if err == ErrInvalidConn { | |||
return nil, "", driver.ErrBadConn | |||
} | |||
return nil, "", err | |||
} | |||
if data[0] == iERR { | |||
return nil, mc.handleErrorPacket(data) | |||
return nil, "", mc.handleErrorPacket(data) | |||
} | |||
// protocol version [1 byte] | |||
if data[0] < minProtocolVersion { | |||
return nil, fmt.Errorf( | |||
return nil, "", fmt.Errorf( | |||
"unsupported protocol version %d. Version %d or higher is required", | |||
data[0], | |||
minProtocolVersion, | |||
@@ -157,7 +183,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { | |||
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 | |||
// first part of the password cipher [8 bytes] | |||
cipher := data[pos : pos+8] | |||
authData := data[pos : pos+8] | |||
// (filler) always 0x00 [1 byte] | |||
pos += 8 + 1 | |||
@@ -165,13 +191,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { | |||
// capability flags (lower 2 bytes) [2 bytes] | |||
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) | |||
if mc.flags&clientProtocol41 == 0 { | |||
return nil, ErrOldProtocol | |||
return nil, "", ErrOldProtocol | |||
} | |||
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { | |||
return nil, ErrNoTLS | |||
return nil, "", ErrNoTLS | |||
} | |||
pos += 2 | |||
plugin := "" | |||
if len(data) > pos { | |||
// character set [1 byte] | |||
// status flags [2 bytes] | |||
@@ -192,32 +219,34 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { | |||
// | |||
// The official Python library uses the fixed length 12 | |||
// which seems to work but technically could have a hidden bug. | |||
cipher = append(cipher, data[pos:pos+12]...) | |||
authData = append(authData, data[pos:pos+12]...) | |||
pos += 13 | |||
// TODO: Verify string termination | |||
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) | |||
// \NUL otherwise | |||
// | |||
//if data[len(data)-1] == 0 { | |||
// return | |||
//} | |||
//return ErrMalformPkt | |||
if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { | |||
plugin = string(data[pos : pos+end]) | |||
} else { | |||
plugin = string(data[pos:]) | |||
} | |||
// make a memory safe copy of the cipher slice | |||
var b [20]byte | |||
copy(b[:], cipher) | |||
return b[:], nil | |||
copy(b[:], authData) | |||
return b[:], plugin, nil | |||
} | |||
plugin = defaultAuthPlugin | |||
// make a memory safe copy of the cipher slice | |||
var b [8]byte | |||
copy(b[:], cipher) | |||
return b[:], nil | |||
copy(b[:], authData) | |||
return b[:], plugin, nil | |||
} | |||
// Client Authentication Packet | |||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse | |||
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error { | |||
// Adjust client flags based on server support | |||
clientFlags := clientProtocol41 | | |||
clientSecureConn | | |||
@@ -241,10 +270,19 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||
clientFlags |= clientMultiStatements | |||
} | |||
// User Password | |||
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) | |||
// encode length of the auth plugin data | |||
var authRespLEIBuf [9]byte | |||
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) | |||
if len(authRespLEI) > 1 { | |||
// if the length can not be written in 1 byte, it must be written as a | |||
// length encoded integer | |||
clientFlags |= clientPluginAuthLenEncClientData | |||
} | |||
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 | |||
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 | |||
if addNUL { | |||
pktLen++ | |||
} | |||
// To specify a db name | |||
if n := len(mc.cfg.DBName); n > 0 { | |||
@@ -255,9 +293,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||
// Calculate packet length and get buffer with that size | |||
data := mc.buf.takeSmallBuffer(pktLen + 4) | |||
if data == nil { | |||
// can not take the buffer. Something must be wrong with the connection | |||
// cannot take the buffer. Something must be wrong with the connection | |||
errLog.Print(ErrBusyBuffer) | |||
return driver.ErrBadConn | |||
return errBadConnNoWrite | |||
} | |||
// ClientFlags [32 bit] | |||
@@ -312,9 +350,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||
data[pos] = 0x00 | |||
pos++ | |||
// ScrambleBuffer [length encoded integer] | |||
data[pos] = byte(len(scrambleBuff)) | |||
pos += 1 + copy(data[pos+1:], scrambleBuff) | |||
// Auth Data [length encoded integer] | |||
pos += copy(data[pos:], authRespLEI) | |||
pos += copy(data[pos:], authResp) | |||
if addNUL { | |||
data[pos] = 0x00 | |||
pos++ | |||
} | |||
// Databasename [null terminated string] | |||
if len(mc.cfg.DBName) > 0 { | |||
@@ -323,72 +365,32 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||
pos++ | |||
} | |||
// Assume native client during response | |||
pos += copy(data[pos:], "mysql_native_password") | |||
pos += copy(data[pos:], plugin) | |||
data[pos] = 0x00 | |||
// Send Auth packet | |||
return mc.writePacket(data) | |||
} | |||
// Client old authentication packet | |||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse | |||
func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { | |||
// User password | |||
scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) | |||
// Calculate the packet length and add a tailing 0 | |||
pktLen := len(scrambleBuff) + 1 | |||
data := mc.buf.takeSmallBuffer(4 + pktLen) | |||
if data == nil { | |||
// can not take the buffer. Something must be wrong with the connection | |||
errLog.Print(ErrBusyBuffer) | |||
return driver.ErrBadConn | |||
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error { | |||
pktLen := 4 + len(authData) | |||
if addNUL { | |||
pktLen++ | |||
} | |||
// Add the scrambled password [null terminated string] | |||
copy(data[4:], scrambleBuff) | |||
data[4+pktLen-1] = 0x00 | |||
return mc.writePacket(data) | |||
} | |||
// Client clear text authentication packet | |||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse | |||
func (mc *mysqlConn) writeClearAuthPacket() error { | |||
// Calculate the packet length and add a tailing 0 | |||
pktLen := len(mc.cfg.Passwd) + 1 | |||
data := mc.buf.takeSmallBuffer(4 + pktLen) | |||
data := mc.buf.takeSmallBuffer(pktLen) | |||
if data == nil { | |||
// can not take the buffer. Something must be wrong with the connection | |||
// cannot take the buffer. Something must be wrong with the connection | |||
errLog.Print(ErrBusyBuffer) | |||
return driver.ErrBadConn | |||
return errBadConnNoWrite | |||
} | |||
// Add the clear password [null terminated string] | |||
copy(data[4:], mc.cfg.Passwd) | |||
data[4+pktLen-1] = 0x00 | |||
return mc.writePacket(data) | |||
} | |||
// Native password authentication method | |||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse | |||
func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { | |||
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) | |||
// Calculate the packet length and add a tailing 0 | |||
pktLen := len(scrambleBuff) | |||
data := mc.buf.takeSmallBuffer(4 + pktLen) | |||
if data == nil { | |||
// can not take the buffer. Something must be wrong with the connection | |||
errLog.Print(ErrBusyBuffer) | |||
return driver.ErrBadConn | |||
// Add the auth data [EOF] | |||
copy(data[4:], authData) | |||
if addNUL { | |||
data[pktLen-1] = 0x00 | |||
} | |||
// Add the scramble | |||
copy(data[4:], scrambleBuff) | |||
return mc.writePacket(data) | |||
} | |||
@@ -402,9 +404,9 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { | |||
data := mc.buf.takeSmallBuffer(4 + 1) | |||
if data == nil { | |||
// can not take the buffer. Something must be wrong with the connection | |||
// cannot take the buffer. Something must be wrong with the connection | |||
errLog.Print(ErrBusyBuffer) | |||
return driver.ErrBadConn | |||
return errBadConnNoWrite | |||
} | |||
// Add command byte | |||
@@ -421,9 +423,9 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { | |||
pktLen := 1 + len(arg) | |||
data := mc.buf.takeBuffer(pktLen + 4) | |||
if data == nil { | |||
// can not take the buffer. Something must be wrong with the connection | |||
// cannot take the buffer. Something must be wrong with the connection | |||
errLog.Print(ErrBusyBuffer) | |||
return driver.ErrBadConn | |||
return errBadConnNoWrite | |||
} | |||
// Add command byte | |||
@@ -442,9 +444,9 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { | |||
data := mc.buf.takeSmallBuffer(4 + 1 + 4) | |||
if data == nil { | |||
// can not take the buffer. Something must be wrong with the connection | |||
// cannot take the buffer. Something must be wrong with the connection | |||
errLog.Print(ErrBusyBuffer) | |||
return driver.ErrBadConn | |||
return errBadConnNoWrite | |||
} | |||
// Add command byte | |||
@@ -464,43 +466,50 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { | |||
* Result Packets * | |||
******************************************************************************/ | |||
// Returns error if Packet is not an 'Result OK'-Packet | |||
func (mc *mysqlConn) readResultOK() ([]byte, error) { | |||
func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { | |||
data, err := mc.readPacket() | |||
if err == nil { | |||
// packet indicator | |||
switch data[0] { | |||
if err != nil { | |||
return nil, "", err | |||
} | |||
case iOK: | |||
return nil, mc.handleOkPacket(data) | |||
// packet indicator | |||
switch data[0] { | |||
case iEOF: | |||
if len(data) > 1 { | |||
pluginEndIndex := bytes.IndexByte(data, 0x00) | |||
plugin := string(data[1:pluginEndIndex]) | |||
cipher := data[pluginEndIndex+1 : len(data)-1] | |||
if plugin == "mysql_old_password" { | |||
// using old_passwords | |||
return cipher, ErrOldPassword | |||
} else if plugin == "mysql_clear_password" { | |||
// using clear text password | |||
return cipher, ErrCleartextPassword | |||
} else if plugin == "mysql_native_password" { | |||
// using mysql default authentication method | |||
return cipher, ErrNativePassword | |||
} else { | |||
return cipher, ErrUnknownPlugin | |||
} | |||
} else { | |||
return nil, ErrOldPassword | |||
} | |||
case iOK: | |||
return nil, "", mc.handleOkPacket(data) | |||
default: // Error otherwise | |||
return nil, mc.handleErrorPacket(data) | |||
case iAuthMoreData: | |||
return data[1:], "", err | |||
case iEOF: | |||
if len(data) < 1 { | |||
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest | |||
return nil, "mysql_old_password", nil | |||
} | |||
pluginEndIndex := bytes.IndexByte(data, 0x00) | |||
if pluginEndIndex < 0 { | |||
return nil, "", ErrMalformPkt | |||
} | |||
plugin := string(data[1:pluginEndIndex]) | |||
authData := data[pluginEndIndex+1:] | |||
return authData, plugin, nil | |||
default: // Error otherwise | |||
return nil, "", mc.handleErrorPacket(data) | |||
} | |||
return nil, err | |||
} | |||
// Returns error if Packet is not an 'Result OK'-Packet | |||
func (mc *mysqlConn) readResultOK() error { | |||
data, err := mc.readPacket() | |||
if err != nil { | |||
return err | |||
} | |||
if data[0] == iOK { | |||
return mc.handleOkPacket(data) | |||
} | |||
return mc.handleErrorPacket(data) | |||
} | |||
// Result Set Header Packet | |||
@@ -543,6 +552,22 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { | |||
// Error Number [16 bit uint] | |||
errno := binary.LittleEndian.Uint16(data[1:3]) | |||
// 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION | |||
// 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) | |||
if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { | |||
// Oops; we are connected to a read-only connection, and won't be able | |||
// to issue any write statements. Since RejectReadOnly is configured, | |||
// we throw away this connection hoping this one would have write | |||
// permission. This is specifically for a possible race condition | |||
// during failover (e.g. on AWS Aurora). See README.md for more. | |||
// | |||
// We explicitly close the connection before returning | |||
// driver.ErrBadConn to ensure that `database/sql` purges this | |||
// connection and initiates a new one for next statement next time. | |||
mc.Close() | |||
return driver.ErrBadConn | |||
} | |||
pos := 3 | |||
// SQL State [optional: # + 5bytes string] | |||
@@ -577,19 +602,12 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { | |||
// server_status [2 bytes] | |||
mc.status = readStatus(data[1+n+m : 1+n+m+2]) | |||
if err := mc.discardResults(); err != nil { | |||
return err | |||
if mc.status&statusMoreResultsExists != 0 { | |||
return nil | |||
} | |||
// warning count [2 bytes] | |||
if !mc.strict { | |||
return nil | |||
} | |||
pos := 1 + n + m + 2 | |||
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { | |||
return mc.getWarnings() | |||
} | |||
return nil | |||
} | |||
@@ -661,14 +679,21 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { | |||
if err != nil { | |||
return nil, err | |||
} | |||
pos += n | |||
// Filler [uint8] | |||
pos++ | |||
// Charset [charset, collation uint8] | |||
columns[i].charSet = data[pos] | |||
pos += 2 | |||
// Length [uint32] | |||
pos += n + 1 + 2 + 4 | |||
columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) | |||
pos += 4 | |||
// Field type [uint8] | |||
columns[i].fieldType = data[pos] | |||
columns[i].fieldType = fieldType(data[pos]) | |||
pos++ | |||
// Flags [uint16] | |||
@@ -691,6 +716,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { | |||
func (rows *textRows) readRow(dest []driver.Value) error { | |||
mc := rows.mc | |||
if rows.rs.done { | |||
return io.EOF | |||
} | |||
data, err := mc.readPacket() | |||
if err != nil { | |||
return err | |||
@@ -700,10 +729,10 @@ func (rows *textRows) readRow(dest []driver.Value) error { | |||
if data[0] == iEOF && len(data) == 5 { | |||
// server_status [2 bytes] | |||
rows.mc.status = readStatus(data[3:]) | |||
if err := rows.mc.discardResults(); err != nil { | |||
return err | |||
rows.rs.done = true | |||
if !rows.HasNextResultSet() { | |||
rows.mc = nil | |||
} | |||
rows.mc = nil | |||
return io.EOF | |||
} | |||
if data[0] == iERR { | |||
@@ -725,7 +754,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { | |||
if !mc.parseTime { | |||
continue | |||
} else { | |||
switch rows.columns[i].fieldType { | |||
switch rows.rs.columns[i].fieldType { | |||
case fieldTypeTimestamp, fieldTypeDateTime, | |||
fieldTypeDate, fieldTypeNewDate: | |||
dest[i], err = parseDateTime( | |||
@@ -797,14 +826,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { | |||
// Reserved [8 bit] | |||
// Warning count [16 bit uint] | |||
if !stmt.mc.strict { | |||
return columnCount, nil | |||
} | |||
// Check for warnings count > 0, only available in MySQL > 4.1 | |||
if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { | |||
return columnCount, stmt.mc.getWarnings() | |||
} | |||
return columnCount, nil | |||
} | |||
return 0, err | |||
@@ -821,7 +843,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { | |||
// 2 bytes paramID | |||
const dataOffset = 1 + 4 + 2 | |||
// Can not use the write buffer since | |||
// Cannot use the write buffer since | |||
// a) the buffer is too small | |||
// b) it is in use | |||
data := make([]byte, 4+1+4+2+len(arg)) | |||
@@ -876,6 +898,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
const minPktLen = 4 + 1 + 4 + 1 + 4 | |||
mc := stmt.mc | |||
// Determine threshould dynamically to avoid packet size shortage. | |||
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) | |||
if longDataSize < 64 { | |||
longDataSize = 64 | |||
} | |||
// Reset packet-sequence | |||
mc.sequence = 0 | |||
@@ -887,9 +915,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
data = mc.buf.takeCompleteBuffer() | |||
} | |||
if data == nil { | |||
// can not take the buffer. Something must be wrong with the connection | |||
// cannot take the buffer. Something must be wrong with the connection | |||
errLog.Print(ErrBusyBuffer) | |||
return driver.ErrBadConn | |||
return errBadConnNoWrite | |||
} | |||
// command [1 byte] | |||
@@ -948,7 +976,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
// build NULL-bitmap | |||
if arg == nil { | |||
nullMask[i/8] |= 1 << (uint(i) & 7) | |||
paramTypes[i+i] = fieldTypeNULL | |||
paramTypes[i+i] = byte(fieldTypeNULL) | |||
paramTypes[i+i+1] = 0x00 | |||
continue | |||
} | |||
@@ -956,7 +984,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
// cache types and values | |||
switch v := arg.(type) { | |||
case int64: | |||
paramTypes[i+i] = fieldTypeLongLong | |||
paramTypes[i+i] = byte(fieldTypeLongLong) | |||
paramTypes[i+i+1] = 0x00 | |||
if cap(paramValues)-len(paramValues)-8 >= 0 { | |||
@@ -972,7 +1000,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
} | |||
case float64: | |||
paramTypes[i+i] = fieldTypeDouble | |||
paramTypes[i+i] = byte(fieldTypeDouble) | |||
paramTypes[i+i+1] = 0x00 | |||
if cap(paramValues)-len(paramValues)-8 >= 0 { | |||
@@ -988,7 +1016,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
} | |||
case bool: | |||
paramTypes[i+i] = fieldTypeTiny | |||
paramTypes[i+i] = byte(fieldTypeTiny) | |||
paramTypes[i+i+1] = 0x00 | |||
if v { | |||
@@ -1000,10 +1028,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
case []byte: | |||
// Common case (non-nil value) first | |||
if v != nil { | |||
paramTypes[i+i] = fieldTypeString | |||
paramTypes[i+i] = byte(fieldTypeString) | |||
paramTypes[i+i+1] = 0x00 | |||
if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { | |||
if len(v) < longDataSize { | |||
paramValues = appendLengthEncodedInteger(paramValues, | |||
uint64(len(v)), | |||
) | |||
@@ -1018,14 +1046,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
// Handle []byte(nil) as a NULL value | |||
nullMask[i/8] |= 1 << (uint(i) & 7) | |||
paramTypes[i+i] = fieldTypeNULL | |||
paramTypes[i+i] = byte(fieldTypeNULL) | |||
paramTypes[i+i+1] = 0x00 | |||
case string: | |||
paramTypes[i+i] = fieldTypeString | |||
paramTypes[i+i] = byte(fieldTypeString) | |||
paramTypes[i+i+1] = 0x00 | |||
if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { | |||
if len(v) < longDataSize { | |||
paramValues = appendLengthEncodedInteger(paramValues, | |||
uint64(len(v)), | |||
) | |||
@@ -1037,23 +1065,25 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
} | |||
case time.Time: | |||
paramTypes[i+i] = fieldTypeString | |||
paramTypes[i+i] = byte(fieldTypeString) | |||
paramTypes[i+i+1] = 0x00 | |||
var val []byte | |||
var a [64]byte | |||
var b = a[:0] | |||
if v.IsZero() { | |||
val = []byte("0000-00-00") | |||
b = append(b, "0000-00-00"...) | |||
} else { | |||
val = []byte(v.In(mc.cfg.Loc).Format(timeFormat)) | |||
b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) | |||
} | |||
paramValues = appendLengthEncodedInteger(paramValues, | |||
uint64(len(val)), | |||
uint64(len(b)), | |||
) | |||
paramValues = append(paramValues, val...) | |||
paramValues = append(paramValues, b...) | |||
default: | |||
return fmt.Errorf("can not convert type: %T", arg) | |||
return fmt.Errorf("cannot convert type: %T", arg) | |||
} | |||
} | |||
@@ -1086,8 +1116,6 @@ func (mc *mysqlConn) discardResults() error { | |||
if err := mc.readUntilEOF(); err != nil { | |||
return err | |||
} | |||
} else { | |||
mc.status &^= statusMoreResultsExists | |||
} | |||
} | |||
return nil | |||
@@ -1105,16 +1133,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
// EOF Packet | |||
if data[0] == iEOF && len(data) == 5 { | |||
rows.mc.status = readStatus(data[3:]) | |||
if err := rows.mc.discardResults(); err != nil { | |||
return err | |||
rows.rs.done = true | |||
if !rows.HasNextResultSet() { | |||
rows.mc = nil | |||
} | |||
rows.mc = nil | |||
return io.EOF | |||
} | |||
mc := rows.mc | |||
rows.mc = nil | |||
// Error otherwise | |||
return rows.mc.handleErrorPacket(data) | |||
return mc.handleErrorPacket(data) | |||
} | |||
// NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] | |||
@@ -1130,14 +1159,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
} | |||
// Convert to byte-coded string | |||
switch rows.columns[i].fieldType { | |||
switch rows.rs.columns[i].fieldType { | |||
case fieldTypeNULL: | |||
dest[i] = nil | |||
continue | |||
// Numeric Types | |||
case fieldTypeTiny: | |||
if rows.columns[i].flags&flagUnsigned != 0 { | |||
if rows.rs.columns[i].flags&flagUnsigned != 0 { | |||
dest[i] = int64(data[pos]) | |||
} else { | |||
dest[i] = int64(int8(data[pos])) | |||
@@ -1146,7 +1175,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
continue | |||
case fieldTypeShort, fieldTypeYear: | |||
if rows.columns[i].flags&flagUnsigned != 0 { | |||
if rows.rs.columns[i].flags&flagUnsigned != 0 { | |||
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) | |||
} else { | |||
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) | |||
@@ -1155,7 +1184,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
continue | |||
case fieldTypeInt24, fieldTypeLong: | |||
if rows.columns[i].flags&flagUnsigned != 0 { | |||
if rows.rs.columns[i].flags&flagUnsigned != 0 { | |||
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) | |||
} else { | |||
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) | |||
@@ -1164,7 +1193,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
continue | |||
case fieldTypeLongLong: | |||
if rows.columns[i].flags&flagUnsigned != 0 { | |||
if rows.rs.columns[i].flags&flagUnsigned != 0 { | |||
val := binary.LittleEndian.Uint64(data[pos : pos+8]) | |||
if val > math.MaxInt64 { | |||
dest[i] = uint64ToString(val) | |||
@@ -1178,7 +1207,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
continue | |||
case fieldTypeFloat: | |||
dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) | |||
dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) | |||
pos += 4 | |||
continue | |||
@@ -1218,10 +1247,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
case isNull: | |||
dest[i] = nil | |||
continue | |||
case rows.columns[i].fieldType == fieldTypeTime: | |||
case rows.rs.columns[i].fieldType == fieldTypeTime: | |||
// database/sql does not support an equivalent to TIME, return a string | |||
var dstlen uint8 | |||
switch decimals := rows.columns[i].decimals; decimals { | |||
switch decimals := rows.rs.columns[i].decimals; decimals { | |||
case 0x00, 0x1f: | |||
dstlen = 8 | |||
case 1, 2, 3, 4, 5, 6: | |||
@@ -1229,7 +1258,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
default: | |||
return fmt.Errorf( | |||
"protocol error, illegal decimals value %d", | |||
rows.columns[i].decimals, | |||
rows.rs.columns[i].decimals, | |||
) | |||
} | |||
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) | |||
@@ -1237,10 +1266,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) | |||
default: | |||
var dstlen uint8 | |||
if rows.columns[i].fieldType == fieldTypeDate { | |||
if rows.rs.columns[i].fieldType == fieldTypeDate { | |||
dstlen = 10 | |||
} else { | |||
switch decimals := rows.columns[i].decimals; decimals { | |||
switch decimals := rows.rs.columns[i].decimals; decimals { | |||
case 0x00, 0x1f: | |||
dstlen = 19 | |||
case 1, 2, 3, 4, 5, 6: | |||
@@ -1248,7 +1277,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
default: | |||
return fmt.Errorf( | |||
"protocol error, illegal decimals value %d", | |||
rows.columns[i].decimals, | |||
rows.rs.columns[i].decimals, | |||
) | |||
} | |||
} | |||
@@ -1264,7 +1293,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
// Please report if this happens! | |||
default: | |||
return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) | |||
return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) | |||
} | |||
} | |||
@@ -11,19 +11,20 @@ package mysql | |||
import ( | |||
"database/sql/driver" | |||
"io" | |||
"math" | |||
"reflect" | |||
) | |||
type mysqlField struct { | |||
tableName string | |||
name string | |||
flags fieldFlag | |||
fieldType byte | |||
decimals byte | |||
type resultSet struct { | |||
columns []mysqlField | |||
columnNames []string | |||
done bool | |||
} | |||
type mysqlRows struct { | |||
mc *mysqlConn | |||
columns []mysqlField | |||
mc *mysqlConn | |||
rs resultSet | |||
finish func() | |||
} | |||
type binaryRows struct { | |||
@@ -34,37 +35,86 @@ type textRows struct { | |||
mysqlRows | |||
} | |||
type emptyRows struct{} | |||
func (rows *mysqlRows) Columns() []string { | |||
columns := make([]string, len(rows.columns)) | |||
if rows.rs.columnNames != nil { | |||
return rows.rs.columnNames | |||
} | |||
columns := make([]string, len(rows.rs.columns)) | |||
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { | |||
for i := range columns { | |||
if tableName := rows.columns[i].tableName; len(tableName) > 0 { | |||
columns[i] = tableName + "." + rows.columns[i].name | |||
if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 { | |||
columns[i] = tableName + "." + rows.rs.columns[i].name | |||
} else { | |||
columns[i] = rows.columns[i].name | |||
columns[i] = rows.rs.columns[i].name | |||
} | |||
} | |||
} else { | |||
for i := range columns { | |||
columns[i] = rows.columns[i].name | |||
columns[i] = rows.rs.columns[i].name | |||
} | |||
} | |||
rows.rs.columnNames = columns | |||
return columns | |||
} | |||
func (rows *mysqlRows) Close() error { | |||
func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { | |||
return rows.rs.columns[i].typeDatabaseName() | |||
} | |||
// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { | |||
// return int64(rows.rs.columns[i].length), true | |||
// } | |||
func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { | |||
return rows.rs.columns[i].flags&flagNotNULL == 0, true | |||
} | |||
func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { | |||
column := rows.rs.columns[i] | |||
decimals := int64(column.decimals) | |||
switch column.fieldType { | |||
case fieldTypeDecimal, fieldTypeNewDecimal: | |||
if decimals > 0 { | |||
return int64(column.length) - 2, decimals, true | |||
} | |||
return int64(column.length) - 1, decimals, true | |||
case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime: | |||
return decimals, decimals, true | |||
case fieldTypeFloat, fieldTypeDouble: | |||
if decimals == 0x1f { | |||
return math.MaxInt64, math.MaxInt64, true | |||
} | |||
return math.MaxInt64, decimals, true | |||
} | |||
return 0, 0, false | |||
} | |||
func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { | |||
return rows.rs.columns[i].scanType() | |||
} | |||
func (rows *mysqlRows) Close() (err error) { | |||
if f := rows.finish; f != nil { | |||
f() | |||
rows.finish = nil | |||
} | |||
mc := rows.mc | |||
if mc == nil { | |||
return nil | |||
} | |||
if mc.netConn == nil { | |||
return ErrInvalidConn | |||
if err := mc.error(); err != nil { | |||
return err | |||
} | |||
// Remove unread packets from stream | |||
err := mc.readUntilEOF() | |||
if !rows.rs.done { | |||
err = mc.readUntilEOF() | |||
} | |||
if err == nil { | |||
if err = mc.discardResults(); err != nil { | |||
return err | |||
@@ -75,22 +125,66 @@ func (rows *mysqlRows) Close() error { | |||
return err | |||
} | |||
func (rows *binaryRows) Next(dest []driver.Value) error { | |||
if mc := rows.mc; mc != nil { | |||
if mc.netConn == nil { | |||
return ErrInvalidConn | |||
func (rows *mysqlRows) HasNextResultSet() (b bool) { | |||
if rows.mc == nil { | |||
return false | |||
} | |||
return rows.mc.status&statusMoreResultsExists != 0 | |||
} | |||
func (rows *mysqlRows) nextResultSet() (int, error) { | |||
if rows.mc == nil { | |||
return 0, io.EOF | |||
} | |||
if err := rows.mc.error(); err != nil { | |||
return 0, err | |||
} | |||
// Remove unread packets from stream | |||
if !rows.rs.done { | |||
if err := rows.mc.readUntilEOF(); err != nil { | |||
return 0, err | |||
} | |||
rows.rs.done = true | |||
} | |||
// Fetch next row from stream | |||
return rows.readRow(dest) | |||
if !rows.HasNextResultSet() { | |||
rows.mc = nil | |||
return 0, io.EOF | |||
} | |||
return io.EOF | |||
rows.rs = resultSet{} | |||
return rows.mc.readResultSetHeaderPacket() | |||
} | |||
func (rows *textRows) Next(dest []driver.Value) error { | |||
func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { | |||
for { | |||
resLen, err := rows.nextResultSet() | |||
if err != nil { | |||
return 0, err | |||
} | |||
if resLen > 0 { | |||
return resLen, nil | |||
} | |||
rows.rs.done = true | |||
} | |||
} | |||
func (rows *binaryRows) NextResultSet() error { | |||
resLen, err := rows.nextNotEmptyResultSet() | |||
if err != nil { | |||
return err | |||
} | |||
rows.rs.columns, err = rows.mc.readColumns(resLen) | |||
return err | |||
} | |||
func (rows *binaryRows) Next(dest []driver.Value) error { | |||
if mc := rows.mc; mc != nil { | |||
if mc.netConn == nil { | |||
return ErrInvalidConn | |||
if err := mc.error(); err != nil { | |||
return err | |||
} | |||
// Fetch next row from stream | |||
@@ -99,14 +193,24 @@ func (rows *textRows) Next(dest []driver.Value) error { | |||
return io.EOF | |||
} | |||
func (rows emptyRows) Columns() []string { | |||
return nil | |||
} | |||
func (rows *textRows) NextResultSet() (err error) { | |||
resLen, err := rows.nextNotEmptyResultSet() | |||
if err != nil { | |||
return err | |||
} | |||
func (rows emptyRows) Close() error { | |||
return nil | |||
rows.rs.columns, err = rows.mc.readColumns(resLen) | |||
return err | |||
} | |||
func (rows emptyRows) Next(dest []driver.Value) error { | |||
func (rows *textRows) Next(dest []driver.Value) error { | |||
if mc := rows.mc; mc != nil { | |||
if err := mc.error(); err != nil { | |||
return err | |||
} | |||
// Fetch next row from stream | |||
return rows.readRow(dest) | |||
} | |||
return io.EOF | |||
} |
@@ -11,6 +11,7 @@ package mysql | |||
import ( | |||
"database/sql/driver" | |||
"fmt" | |||
"io" | |||
"reflect" | |||
"strconv" | |||
) | |||
@@ -19,12 +20,14 @@ type mysqlStmt struct { | |||
mc *mysqlConn | |||
id uint32 | |||
paramCount int | |||
columns []mysqlField // cached from the first query | |||
} | |||
func (stmt *mysqlStmt) Close() error { | |||
if stmt.mc == nil || stmt.mc.netConn == nil { | |||
errLog.Print(ErrInvalidConn) | |||
if stmt.mc == nil || stmt.mc.closed.IsSet() { | |||
// driver.Stmt.Close can be called more than once, thus this function | |||
// has to be idempotent. | |||
// See also Issue #450 and golang/go#16019. | |||
//errLog.Print(ErrInvalidConn) | |||
return driver.ErrBadConn | |||
} | |||
@@ -42,14 +45,14 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { | |||
} | |||
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { | |||
if stmt.mc.netConn == nil { | |||
if stmt.mc.closed.IsSet() { | |||
errLog.Print(ErrInvalidConn) | |||
return nil, driver.ErrBadConn | |||
} | |||
// Send command | |||
err := stmt.writeExecutePacket(args) | |||
if err != nil { | |||
return nil, err | |||
return nil, stmt.mc.markBadConn(err) | |||
} | |||
mc := stmt.mc | |||
@@ -59,37 +62,45 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { | |||
// Read Result | |||
resLen, err := mc.readResultSetHeaderPacket() | |||
if err == nil { | |||
if resLen > 0 { | |||
// Columns | |||
err = mc.readUntilEOF() | |||
if err != nil { | |||
return nil, err | |||
} | |||
// Rows | |||
err = mc.readUntilEOF() | |||
if err != nil { | |||
return nil, err | |||
} | |||
if resLen > 0 { | |||
// Columns | |||
if err = mc.readUntilEOF(); err != nil { | |||
return nil, err | |||
} | |||
if err == nil { | |||
return &mysqlResult{ | |||
affectedRows: int64(mc.affectedRows), | |||
insertId: int64(mc.insertId), | |||
}, nil | |||
// Rows | |||
if err := mc.readUntilEOF(); err != nil { | |||
return nil, err | |||
} | |||
} | |||
return nil, err | |||
if err := mc.discardResults(); err != nil { | |||
return nil, err | |||
} | |||
return &mysqlResult{ | |||
affectedRows: int64(mc.affectedRows), | |||
insertId: int64(mc.insertId), | |||
}, nil | |||
} | |||
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | |||
if stmt.mc.netConn == nil { | |||
return stmt.query(args) | |||
} | |||
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { | |||
if stmt.mc.closed.IsSet() { | |||
errLog.Print(ErrInvalidConn) | |||
return nil, driver.ErrBadConn | |||
} | |||
// Send command | |||
err := stmt.writeExecutePacket(args) | |||
if err != nil { | |||
return nil, err | |||
return nil, stmt.mc.markBadConn(err) | |||
} | |||
mc := stmt.mc | |||
@@ -104,14 +115,15 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | |||
if resLen > 0 { | |||
rows.mc = mc | |||
// Columns | |||
// If not cached, read them and cache them | |||
if stmt.columns == nil { | |||
rows.columns, err = mc.readColumns(resLen) | |||
stmt.columns = rows.columns | |||
} else { | |||
rows.columns = stmt.columns | |||
err = mc.readUntilEOF() | |||
rows.rs.columns, err = mc.readColumns(resLen) | |||
} else { | |||
rows.rs.done = true | |||
switch err := rows.NextResultSet(); err { | |||
case nil, io.EOF: | |||
return rows, nil | |||
default: | |||
return nil, err | |||
} | |||
} | |||
@@ -120,19 +132,36 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | |||
type converter struct{} | |||
// ConvertValue mirrors the reference/default converter in database/sql/driver | |||
// with _one_ exception. We support uint64 with their high bit and the default | |||
// implementation does not. This function should be kept in sync with | |||
// database/sql/driver defaultConverter.ConvertValue() except for that | |||
// deliberate difference. | |||
func (c converter) ConvertValue(v interface{}) (driver.Value, error) { | |||
if driver.IsValue(v) { | |||
return v, nil | |||
} | |||
if vr, ok := v.(driver.Valuer); ok { | |||
sv, err := callValuerValue(vr) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if !driver.IsValue(sv) { | |||
return nil, fmt.Errorf("non-Value type %T returned from Value", sv) | |||
} | |||
return sv, nil | |||
} | |||
rv := reflect.ValueOf(v) | |||
switch rv.Kind() { | |||
case reflect.Ptr: | |||
// indirect pointers | |||
if rv.IsNil() { | |||
return nil, nil | |||
} else { | |||
return c.ConvertValue(rv.Elem().Interface()) | |||
} | |||
return c.ConvertValue(rv.Elem().Interface()) | |||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
return rv.Int(), nil | |||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: | |||
@@ -145,6 +174,38 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { | |||
return int64(u64), nil | |||
case reflect.Float32, reflect.Float64: | |||
return rv.Float(), nil | |||
case reflect.Bool: | |||
return rv.Bool(), nil | |||
case reflect.Slice: | |||
ek := rv.Type().Elem().Kind() | |||
if ek == reflect.Uint8 { | |||
return rv.Bytes(), nil | |||
} | |||
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) | |||
case reflect.String: | |||
return rv.String(), nil | |||
} | |||
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) | |||
} | |||
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() | |||
// callValuerValue returns vr.Value(), with one exception: | |||
// If vr.Value is an auto-generated method on a pointer type and the | |||
// pointer is nil, it would panic at runtime in the panicwrap | |||
// method. Treat it like nil instead. | |||
// | |||
// This is so people can implement driver.Value on value types and | |||
// still use nil pointers to those types to mean nil/NULL, just like | |||
// string/*string. | |||
// | |||
// This is an exact copy of the same-named unexported function from the | |||
// database/sql package. | |||
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { | |||
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && | |||
rv.IsNil() && | |||
rv.Type().Elem().Implements(valuerReflectType) { | |||
return nil, nil | |||
} | |||
return vr.Value() | |||
} |
@@ -13,7 +13,7 @@ type mysqlTx struct { | |||
} | |||
func (tx *mysqlTx) Commit() (err error) { | |||
if tx.mc == nil || tx.mc.netConn == nil { | |||
if tx.mc == nil || tx.mc.closed.IsSet() { | |||
return ErrInvalidConn | |||
} | |||
err = tx.mc.exec("COMMIT") | |||
@@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { | |||
} | |||
func (tx *mysqlTx) Rollback() (err error) { | |||
if tx.mc == nil || tx.mc.netConn == nil { | |||
if tx.mc == nil || tx.mc.closed.IsSet() { | |||
return ErrInvalidConn | |||
} | |||
err = tx.mc.exec("ROLLBACK") | |||
@@ -9,23 +9,29 @@ | |||
package mysql | |||
import ( | |||
"crypto/sha1" | |||
"crypto/tls" | |||
"database/sql/driver" | |||
"encoding/binary" | |||
"fmt" | |||
"io" | |||
"strings" | |||
"sync" | |||
"sync/atomic" | |||
"time" | |||
) | |||
// Registry for custom tls.Configs | |||
var ( | |||
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs | |||
tlsConfigLock sync.RWMutex | |||
tlsConfigRegistry map[string]*tls.Config | |||
) | |||
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. | |||
// Use the key as a value in the DSN where tls=value. | |||
// | |||
// Note: The provided tls.Config is exclusively owned by the driver after | |||
// registering it. | |||
// | |||
// rootCertPool := x509.NewCertPool() | |||
// pem, err := ioutil.ReadFile("/path/ca-cert.pem") | |||
// if err != nil { | |||
@@ -51,19 +57,32 @@ func RegisterTLSConfig(key string, config *tls.Config) error { | |||
return fmt.Errorf("key '%s' is reserved", key) | |||
} | |||
if tlsConfigRegister == nil { | |||
tlsConfigRegister = make(map[string]*tls.Config) | |||
tlsConfigLock.Lock() | |||
if tlsConfigRegistry == nil { | |||
tlsConfigRegistry = make(map[string]*tls.Config) | |||
} | |||
tlsConfigRegister[key] = config | |||
tlsConfigRegistry[key] = config | |||
tlsConfigLock.Unlock() | |||
return nil | |||
} | |||
// DeregisterTLSConfig removes the tls.Config associated with key. | |||
func DeregisterTLSConfig(key string) { | |||
if tlsConfigRegister != nil { | |||
delete(tlsConfigRegister, key) | |||
tlsConfigLock.Lock() | |||
if tlsConfigRegistry != nil { | |||
delete(tlsConfigRegistry, key) | |||
} | |||
tlsConfigLock.Unlock() | |||
} | |||
func getTLSConfigClone(key string) (config *tls.Config) { | |||
tlsConfigLock.RLock() | |||
if v, ok := tlsConfigRegistry[key]; ok { | |||
config = cloneTLSConfig(v) | |||
} | |||
tlsConfigLock.RUnlock() | |||
return | |||
} | |||
// Returns the bool value of the input. | |||
@@ -81,119 +100,6 @@ func readBool(input string) (value bool, valid bool) { | |||
} | |||
/****************************************************************************** | |||
* Authentication * | |||
******************************************************************************/ | |||
// Encrypt password using 4.1+ method | |||
func scramblePassword(scramble, password []byte) []byte { | |||
if len(password) == 0 { | |||
return nil | |||
} | |||
// stage1Hash = SHA1(password) | |||
crypt := sha1.New() | |||
crypt.Write(password) | |||
stage1 := crypt.Sum(nil) | |||
// scrambleHash = SHA1(scramble + SHA1(stage1Hash)) | |||
// inner Hash | |||
crypt.Reset() | |||
crypt.Write(stage1) | |||
hash := crypt.Sum(nil) | |||
// outer Hash | |||
crypt.Reset() | |||
crypt.Write(scramble) | |||
crypt.Write(hash) | |||
scramble = crypt.Sum(nil) | |||
// token = scrambleHash XOR stage1Hash | |||
for i := range scramble { | |||
scramble[i] ^= stage1[i] | |||
} | |||
return scramble | |||
} | |||
// Encrypt password using pre 4.1 (old password) method | |||
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c | |||
type myRnd struct { | |||
seed1, seed2 uint32 | |||
} | |||
const myRndMaxVal = 0x3FFFFFFF | |||
// Pseudo random number generator | |||
func newMyRnd(seed1, seed2 uint32) *myRnd { | |||
return &myRnd{ | |||
seed1: seed1 % myRndMaxVal, | |||
seed2: seed2 % myRndMaxVal, | |||
} | |||
} | |||
// Tested to be equivalent to MariaDB's floating point variant | |||
// http://play.golang.org/p/QHvhd4qved | |||
// http://play.golang.org/p/RG0q4ElWDx | |||
func (r *myRnd) NextByte() byte { | |||
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal | |||
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal | |||
return byte(uint64(r.seed1) * 31 / myRndMaxVal) | |||
} | |||
// Generate binary hash from byte string using insecure pre 4.1 method | |||
func pwHash(password []byte) (result [2]uint32) { | |||
var add uint32 = 7 | |||
var tmp uint32 | |||
result[0] = 1345345333 | |||
result[1] = 0x12345671 | |||
for _, c := range password { | |||
// skip spaces and tabs in password | |||
if c == ' ' || c == '\t' { | |||
continue | |||
} | |||
tmp = uint32(c) | |||
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) | |||
result[1] += (result[1] << 8) ^ result[0] | |||
add += tmp | |||
} | |||
// Remove sign bit (1<<31)-1) | |||
result[0] &= 0x7FFFFFFF | |||
result[1] &= 0x7FFFFFFF | |||
return | |||
} | |||
// Encrypt password using insecure pre 4.1 method | |||
func scrambleOldPassword(scramble, password []byte) []byte { | |||
if len(password) == 0 { | |||
return nil | |||
} | |||
scramble = scramble[:8] | |||
hashPw := pwHash(password) | |||
hashSc := pwHash(scramble) | |||
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) | |||
var out [8]byte | |||
for i := range out { | |||
out[i] = r.NextByte() + 64 | |||
} | |||
mask := r.NextByte() | |||
for i := range out { | |||
out[i] ^= mask | |||
} | |||
return out[:] | |||
} | |||
/****************************************************************************** | |||
* Time related utils * | |||
******************************************************************************/ | |||
@@ -519,7 +425,7 @@ func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { | |||
// Check data length | |||
if len(b) >= n { | |||
return b[n-int(num) : n], false, n, nil | |||
return b[n-int(num) : n : n], false, n, nil | |||
} | |||
return nil, false, n, io.EOF | |||
} | |||
@@ -548,8 +454,8 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { | |||
if len(b) == 0 { | |||
return 0, true, 1 | |||
} | |||
switch b[0] { | |||
switch b[0] { | |||
// 251: NULL | |||
case 0xfb: | |||
return 0, true, 1 | |||
@@ -738,3 +644,67 @@ func escapeStringQuotes(buf []byte, v string) []byte { | |||
return buf[:pos] | |||
} | |||
/****************************************************************************** | |||
* Sync utils * | |||
******************************************************************************/ | |||
// noCopy may be embedded into structs which must not be copied | |||
// after the first use. | |||
// | |||
// See https://github.com/golang/go/issues/8005#issuecomment-190753527 | |||
// for details. | |||
type noCopy struct{} | |||
// Lock is a no-op used by -copylocks checker from `go vet`. | |||
func (*noCopy) Lock() {} | |||
// atomicBool is a wrapper around uint32 for usage as a boolean value with | |||
// atomic access. | |||
type atomicBool struct { | |||
_noCopy noCopy | |||
value uint32 | |||
} | |||
// IsSet returns wether the current boolean value is true | |||
func (ab *atomicBool) IsSet() bool { | |||
return atomic.LoadUint32(&ab.value) > 0 | |||
} | |||
// Set sets the value of the bool regardless of the previous value | |||
func (ab *atomicBool) Set(value bool) { | |||
if value { | |||
atomic.StoreUint32(&ab.value, 1) | |||
} else { | |||
atomic.StoreUint32(&ab.value, 0) | |||
} | |||
} | |||
// TrySet sets the value of the bool and returns wether the value changed | |||
func (ab *atomicBool) TrySet(value bool) bool { | |||
if value { | |||
return atomic.SwapUint32(&ab.value, 1) == 0 | |||
} | |||
return atomic.SwapUint32(&ab.value, 0) > 0 | |||
} | |||
// atomicError is a wrapper for atomically accessed error values | |||
type atomicError struct { | |||
_noCopy noCopy | |||
value atomic.Value | |||
} | |||
// Set sets the error value regardless of the previous value. | |||
// The value must not be nil | |||
func (ae *atomicError) Set(value error) { | |||
ae.value.Store(value) | |||
} | |||
// Value returns the current error value | |||
func (ae *atomicError) Value() error { | |||
if v := ae.value.Load(); v != nil { | |||
// this will panic if the value doesn't implement the error interface | |||
return v.(error) | |||
} | |||
return nil | |||
} |
@@ -0,0 +1,40 @@ | |||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||
// | |||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. | |||
// | |||
// This Source Code Form is subject to the terms of the Mozilla Public | |||
// License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||
// You can obtain one at http://mozilla.org/MPL/2.0/. | |||
// +build go1.7 | |||
// +build !go1.8 | |||
package mysql | |||
import "crypto/tls" | |||
func cloneTLSConfig(c *tls.Config) *tls.Config { | |||
return &tls.Config{ | |||
Rand: c.Rand, | |||
Time: c.Time, | |||
Certificates: c.Certificates, | |||
NameToCertificate: c.NameToCertificate, | |||
GetCertificate: c.GetCertificate, | |||
RootCAs: c.RootCAs, | |||
NextProtos: c.NextProtos, | |||
ServerName: c.ServerName, | |||
ClientAuth: c.ClientAuth, | |||
ClientCAs: c.ClientCAs, | |||
InsecureSkipVerify: c.InsecureSkipVerify, | |||
CipherSuites: c.CipherSuites, | |||
PreferServerCipherSuites: c.PreferServerCipherSuites, | |||
SessionTicketsDisabled: c.SessionTicketsDisabled, | |||
SessionTicketKey: c.SessionTicketKey, | |||
ClientSessionCache: c.ClientSessionCache, | |||
MinVersion: c.MinVersion, | |||
MaxVersion: c.MaxVersion, | |||
CurvePreferences: c.CurvePreferences, | |||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, | |||
Renegotiation: c.Renegotiation, | |||
} | |||
} |
@@ -0,0 +1,50 @@ | |||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||
// | |||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. | |||
// | |||
// This Source Code Form is subject to the terms of the Mozilla Public | |||
// License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||
// You can obtain one at http://mozilla.org/MPL/2.0/. | |||
// +build go1.8 | |||
package mysql | |||
import ( | |||
"crypto/tls" | |||
"database/sql" | |||
"database/sql/driver" | |||
"errors" | |||
"fmt" | |||
) | |||
func cloneTLSConfig(c *tls.Config) *tls.Config { | |||
return c.Clone() | |||
} | |||
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { | |||
dargs := make([]driver.Value, len(named)) | |||
for n, param := range named { | |||
if len(param.Name) > 0 { | |||
// TODO: support the use of Named Parameters #561 | |||
return nil, errors.New("mysql: driver does not support the use of Named Parameters") | |||
} | |||
dargs[n] = param.Value | |||
} | |||
return dargs, nil | |||
} | |||
func mapIsolationLevel(level driver.IsolationLevel) (string, error) { | |||
switch sql.IsolationLevel(level) { | |||
case sql.LevelRepeatableRead: | |||
return "REPEATABLE READ", nil | |||
case sql.LevelReadCommitted: | |||
return "READ COMMITTED", nil | |||
case sql.LevelReadUncommitted: | |||
return "READ UNCOMMITTED", nil | |||
case sql.LevelSerializable: | |||
return "SERIALIZABLE", nil | |||
default: | |||
return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) | |||
} | |||
} |