@@ -288,17 +288,19 @@ | |||||
[[projects]] | [[projects]] | ||||
name = "github.com/go-xorm/builder" | name = "github.com/go-xorm/builder" | ||||
packages = ["."] | packages = ["."] | ||||
revision = "488224409dd8aa2ce7a5baf8d10d55764a913738" | |||||
revision = "dc8bf48f58fab2b4da338ffd25191905fd741b8f" | |||||
version = "v0.3.0" | |||||
[[projects]] | [[projects]] | ||||
name = "github.com/go-xorm/core" | name = "github.com/go-xorm/core" | ||||
packages = ["."] | packages = ["."] | ||||
revision = "cb1d0ca71f42d3ee1bf4aba7daa16099bc31a7e9" | |||||
revision = "c10e21e7e1cec20e09398f2dfae385e58c8df555" | |||||
version = "v0.6.0" | |||||
[[projects]] | [[projects]] | ||||
name = "github.com/go-xorm/xorm" | name = "github.com/go-xorm/xorm" | ||||
packages = ["."] | packages = ["."] | ||||
revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03" | |||||
revision = "ad69f7d8f0861a29438154bb0a20b60501298480" | |||||
[[projects]] | [[projects]] | ||||
branch = "master" | branch = "master" | ||||
@@ -701,6 +703,6 @@ | |||||
[solve-meta] | [solve-meta] | ||||
analyzer-name = "dep" | analyzer-name = "dep" | ||||
analyzer-version = 1 | analyzer-version = 1 | ||||
inputs-digest = "59451a3ad1d449f75c5e9035daf542a377c5c4a397e219bebec0aa0007ab9c39" | |||||
inputs-digest = "5ae18d543bbb8186589c003422b333097d67bb5fed8b4c294be70c012ccffc94" | |||||
solver-name = "gps-cdcl" | solver-name = "gps-cdcl" | ||||
solver-version = 1 | solver-version = 1 |
@@ -33,7 +33,7 @@ ignored = ["google.golang.org/appengine*"] | |||||
[[override]] | [[override]] | ||||
name = "github.com/go-xorm/xorm" | name = "github.com/go-xorm/xorm" | ||||
#version = "0.6.5" | #version = "0.6.5" | ||||
revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03" | |||||
revision = "ad69f7d8f0861a29438154bb0a20b60501298480" | |||||
[[override]] | [[override]] | ||||
name = "github.com/go-sql-driver/mysql" | name = "github.com/go-sql-driver/mysql" | ||||
@@ -1297,7 +1297,7 @@ func getParticipantsByIssueID(e Engine, issueID int64) ([]*User, error) { | |||||
And("`comment`.type = ?", CommentTypeComment). | And("`comment`.type = ?", CommentTypeComment). | ||||
And("`user`.is_active = ?", true). | And("`user`.is_active = ?", true). | ||||
And("`user`.prohibit_login = ?", false). | And("`user`.prohibit_login = ?", false). | ||||
Join("INNER", "user", "`user`.id = `comment`.poster_id"). | |||||
Join("INNER", "`user`", "`user`.id = `comment`.poster_id"). | |||||
Distinct("poster_id"). | Distinct("poster_id"). | ||||
Find(&userIDs); err != nil { | Find(&userIDs); err != nil { | ||||
return nil, fmt.Errorf("get poster IDs: %v", err) | return nil, fmt.Errorf("get poster IDs: %v", err) | ||||
@@ -166,7 +166,7 @@ func (issues IssueList) loadAssignees(e Engine) error { | |||||
var assignees = make(map[int64][]*User, len(issues)) | var assignees = make(map[int64][]*User, len(issues)) | ||||
rows, err := e.Table("issue_assignees"). | rows, err := e.Table("issue_assignees"). | ||||
Join("INNER", "user", "`user`.id = `issue_assignees`.assignee_id"). | |||||
Join("INNER", "`user`", "`user`.id = `issue_assignees`.assignee_id"). | |||||
In("`issue_assignees`.issue_id", issues.getIssueIDs()). | In("`issue_assignees`.issue_id", issues.getIssueIDs()). | ||||
Rows(new(AssigneeIssue)) | Rows(new(AssigneeIssue)) | ||||
if err != nil { | if err != nil { | ||||
@@ -67,7 +67,7 @@ func getIssueWatchers(e Engine, issueID int64) (watches []*IssueWatch, err error | |||||
Where("`issue_watch`.issue_id = ?", issueID). | Where("`issue_watch`.issue_id = ?", issueID). | ||||
And("`user`.is_active = ?", true). | And("`user`.is_active = ?", true). | ||||
And("`user`.prohibit_login = ?", false). | And("`user`.prohibit_login = ?", false). | ||||
Join("INNER", "user", "`user`.id = `issue_watch`.user_id"). | |||||
Join("INNER", "`user`", "`user`.id = `issue_watch`.user_id"). | |||||
Find(&watches) | Find(&watches) | ||||
return | return | ||||
} | } | ||||
@@ -383,7 +383,7 @@ func GetOwnedOrgsByUserIDDesc(userID int64, desc string) ([]*User, error) { | |||||
func GetOrgUsersByUserID(uid int64, all bool) ([]*OrgUser, error) { | func GetOrgUsersByUserID(uid int64, all bool) ([]*OrgUser, error) { | ||||
ous := make([]*OrgUser, 0, 10) | ous := make([]*OrgUser, 0, 10) | ||||
sess := x. | sess := x. | ||||
Join("LEFT", "user", "`org_user`.org_id=`user`.id"). | |||||
Join("LEFT", "`user`", "`org_user`.org_id=`user`.id"). | |||||
Where("`org_user`.uid=?", uid) | Where("`org_user`.uid=?", uid) | ||||
if !all { | if !all { | ||||
// Only show public organizations | // Only show public organizations | ||||
@@ -575,7 +575,7 @@ func (org *User) getUserTeams(e Engine, userID int64, cols ...string) ([]*Team, | |||||
return teams, e. | return teams, e. | ||||
Where("`team_user`.org_id = ?", org.ID). | Where("`team_user`.org_id = ?", org.ID). | ||||
Join("INNER", "team_user", "`team_user`.team_id = team.id"). | Join("INNER", "team_user", "`team_user`.team_id = team.id"). | ||||
Join("INNER", "user", "`user`.id=team_user.uid"). | |||||
Join("INNER", "`user`", "`user`.id=team_user.uid"). | |||||
And("`team_user`.uid = ?", userID). | And("`team_user`.uid = ?", userID). | ||||
Asc("`user`.name"). | Asc("`user`.name"). | ||||
Cols(cols...). | Cols(cols...). | ||||
@@ -1958,7 +1958,7 @@ func DeleteRepository(doer *User, uid, repoID int64) error { | |||||
func GetRepositoryByOwnerAndName(ownerName, repoName string) (*Repository, error) { | func GetRepositoryByOwnerAndName(ownerName, repoName string) (*Repository, error) { | ||||
var repo Repository | var repo Repository | ||||
has, err := x.Select("repository.*"). | has, err := x.Select("repository.*"). | ||||
Join("INNER", "user", "`user`.id = repository.owner_id"). | |||||
Join("INNER", "`user`", "`user`.id = repository.owner_id"). | |||||
Where("repository.lower_name = ?", strings.ToLower(repoName)). | Where("repository.lower_name = ?", strings.ToLower(repoName)). | ||||
And("`user`.lower_name = ?", strings.ToLower(ownerName)). | And("`user`.lower_name = ?", strings.ToLower(ownerName)). | ||||
Get(&repo) | Get(&repo) | ||||
@@ -54,7 +54,7 @@ func getWatchers(e Engine, repoID int64) ([]*Watch, error) { | |||||
return watches, e.Where("`watch`.repo_id=?", repoID). | return watches, e.Where("`watch`.repo_id=?", repoID). | ||||
And("`user`.is_active=?", true). | And("`user`.is_active=?", true). | ||||
And("`user`.prohibit_login=?", false). | And("`user`.prohibit_login=?", false). | ||||
Join("INNER", "user", "`user`.id = `watch`.user_id"). | |||||
Join("INNER", "`user`", "`user`.id = `watch`.user_id"). | |||||
Find(&watches) | Find(&watches) | ||||
} | } | ||||
@@ -374,9 +374,9 @@ func (u *User) GetFollowers(page int) ([]*User, error) { | |||||
Limit(ItemsPerPage, (page-1)*ItemsPerPage). | Limit(ItemsPerPage, (page-1)*ItemsPerPage). | ||||
Where("follow.follow_id=?", u.ID) | Where("follow.follow_id=?", u.ID) | ||||
if setting.UsePostgreSQL { | if setting.UsePostgreSQL { | ||||
sess = sess.Join("LEFT", "follow", `"user".id=follow.user_id`) | |||||
sess = sess.Join("LEFT", "follow", "`user`.id=follow.user_id") | |||||
} else { | } else { | ||||
sess = sess.Join("LEFT", "follow", "user.id=follow.user_id") | |||||
sess = sess.Join("LEFT", "follow", "`user`.id=follow.user_id") | |||||
} | } | ||||
return users, sess.Find(&users) | return users, sess.Find(&users) | ||||
} | } | ||||
@@ -393,9 +393,9 @@ func (u *User) GetFollowing(page int) ([]*User, error) { | |||||
Limit(ItemsPerPage, (page-1)*ItemsPerPage). | Limit(ItemsPerPage, (page-1)*ItemsPerPage). | ||||
Where("follow.user_id=?", u.ID) | Where("follow.user_id=?", u.ID) | ||||
if setting.UsePostgreSQL { | if setting.UsePostgreSQL { | ||||
sess = sess.Join("LEFT", "follow", `"user".id=follow.follow_id`) | |||||
sess = sess.Join("LEFT", "follow", "`user`.id=follow.follow_id") | |||||
} else { | } else { | ||||
sess = sess.Join("LEFT", "follow", "user.id=follow.follow_id") | |||||
sess = sess.Join("LEFT", "follow", "`user`.id=follow.follow_id") | |||||
} | } | ||||
return users, sess.Find(&users) | return users, sess.Find(&users) | ||||
} | } | ||||
@@ -4,6 +4,10 @@ | |||||
package builder | package builder | ||||
import ( | |||||
"fmt" | |||||
) | |||||
type optype byte | type optype byte | ||||
const ( | const ( | ||||
@@ -29,6 +33,9 @@ type Builder struct { | |||||
joins []join | joins []join | ||||
inserts Eq | inserts Eq | ||||
updates []Eq | updates []Eq | ||||
orderBy string | |||||
groupBy string | |||||
having string | |||||
} | } | ||||
// Select creates a select Builder | // Select creates a select Builder | ||||
@@ -67,6 +74,11 @@ func (b *Builder) From(tableName string) *Builder { | |||||
return b | return b | ||||
} | } | ||||
// TableName returns the table name | |||||
func (b *Builder) TableName() string { | |||||
return b.tableName | |||||
} | |||||
// Into sets insert table name | // Into sets insert table name | ||||
func (b *Builder) Into(tableName string) *Builder { | func (b *Builder) Into(tableName string) *Builder { | ||||
b.tableName = tableName | b.tableName = tableName | ||||
@@ -178,6 +190,33 @@ func (b *Builder) ToSQL() (string, []interface{}, error) { | |||||
return w.writer.String(), w.args, nil | return w.writer.String(), w.args, nil | ||||
} | } | ||||
// ConvertPlaceholder replaces ? to $1, $2 ... or :1, :2 ... according prefix | |||||
func ConvertPlaceholder(sql, prefix string) (string, error) { | |||||
buf := StringBuilder{} | |||||
var j, start = 0, 0 | |||||
for i := 0; i < len(sql); i++ { | |||||
if sql[i] == '?' { | |||||
_, err := buf.WriteString(sql[start:i]) | |||||
if err != nil { | |||||
return "", err | |||||
} | |||||
start = i + 1 | |||||
_, err = buf.WriteString(prefix) | |||||
if err != nil { | |||||
return "", err | |||||
} | |||||
j = j + 1 | |||||
_, err = buf.WriteString(fmt.Sprintf("%d", j)) | |||||
if err != nil { | |||||
return "", err | |||||
} | |||||
} | |||||
} | |||||
return buf.String(), nil | |||||
} | |||||
// ToSQL convert a builder or condtions to SQL and args | // ToSQL convert a builder or condtions to SQL and args | ||||
func ToSQL(cond interface{}) (string, []interface{}, error) { | func ToSQL(cond interface{}) (string, []interface{}, error) { | ||||
switch cond.(type) { | switch cond.(type) { | ||||
@@ -15,7 +15,7 @@ func (b *Builder) insertWriteTo(w Writer) error { | |||||
return errors.New("no table indicated") | return errors.New("no table indicated") | ||||
} | } | ||||
if len(b.inserts) <= 0 { | if len(b.inserts) <= 0 { | ||||
return errors.New("no column to be update") | |||||
return errors.New("no column to be insert") | |||||
} | } | ||||
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.tableName); err != nil { | if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.tableName); err != nil { | ||||
@@ -26,7 +26,9 @@ func (b *Builder) insertWriteTo(w Writer) error { | |||||
var bs []byte | var bs []byte | ||||
var valBuffer = bytes.NewBuffer(bs) | var valBuffer = bytes.NewBuffer(bs) | ||||
var i = 0 | var i = 0 | ||||
for col, value := range b.inserts { | |||||
for _, col := range b.inserts.sortedKeys() { | |||||
value := b.inserts[col] | |||||
fmt.Fprint(w, col) | fmt.Fprint(w, col) | ||||
if e, ok := value.(expr); ok { | if e, ok := value.(expr); ok { | ||||
fmt.Fprint(valBuffer, e.sql) | fmt.Fprint(valBuffer, e.sql) | ||||
@@ -34,24 +34,65 @@ func (b *Builder) selectWriteTo(w Writer) error { | |||||
} | } | ||||
} | } | ||||
if _, err := fmt.Fprintf(w, " FROM %s", b.tableName); err != nil { | |||||
if _, err := fmt.Fprint(w, " FROM ", b.tableName); err != nil { | |||||
return err | return err | ||||
} | } | ||||
for _, v := range b.joins { | for _, v := range b.joins { | ||||
fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable) | |||||
if _, err := fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable); err != nil { | |||||
return err | |||||
} | |||||
if err := v.joinCond.WriteTo(w); err != nil { | if err := v.joinCond.WriteTo(w); err != nil { | ||||
return err | return err | ||||
} | } | ||||
} | } | ||||
if !b.cond.IsValid() { | |||||
return nil | |||||
if b.cond.IsValid() { | |||||
if _, err := fmt.Fprint(w, " WHERE "); err != nil { | |||||
return err | |||||
} | |||||
if err := b.cond.WriteTo(w); err != nil { | |||||
return err | |||||
} | |||||
} | } | ||||
if _, err := fmt.Fprint(w, " WHERE "); err != nil { | |||||
return err | |||||
if len(b.groupBy) > 0 { | |||||
if _, err := fmt.Fprint(w, " GROUP BY ", b.groupBy); err != nil { | |||||
return err | |||||
} | |||||
} | } | ||||
return b.cond.WriteTo(w) | |||||
if len(b.having) > 0 { | |||||
if _, err := fmt.Fprint(w, " HAVING ", b.having); err != nil { | |||||
return err | |||||
} | |||||
} | |||||
if len(b.orderBy) > 0 { | |||||
if _, err := fmt.Fprint(w, " ORDER BY ", b.orderBy); err != nil { | |||||
return err | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
// OrderBy orderBy SQL | |||||
func (b *Builder) OrderBy(orderBy string) *Builder { | |||||
b.orderBy = orderBy | |||||
return b | |||||
} | |||||
// GroupBy groupby SQL | |||||
func (b *Builder) GroupBy(groupby string) *Builder { | |||||
b.groupBy = groupby | |||||
return b | |||||
} | |||||
// Having having SQL | |||||
func (b *Builder) Having(having string) *Builder { | |||||
b.having = having | |||||
return b | |||||
} | } |
@@ -5,7 +5,6 @@ | |||||
package builder | package builder | ||||
import ( | import ( | ||||
"bytes" | |||||
"io" | "io" | ||||
) | ) | ||||
@@ -19,15 +18,15 @@ var _ Writer = NewWriter() | |||||
// BytesWriter implments Writer and save SQL in bytes.Buffer | // BytesWriter implments Writer and save SQL in bytes.Buffer | ||||
type BytesWriter struct { | type BytesWriter struct { | ||||
writer *bytes.Buffer | |||||
buffer []byte | |||||
writer *StringBuilder | |||||
args []interface{} | args []interface{} | ||||
} | } | ||||
// NewWriter creates a new string writer | // NewWriter creates a new string writer | ||||
func NewWriter() *BytesWriter { | func NewWriter() *BytesWriter { | ||||
w := &BytesWriter{} | |||||
w.writer = bytes.NewBuffer(w.buffer) | |||||
w := &BytesWriter{ | |||||
writer: &StringBuilder{}, | |||||
} | |||||
return w | return w | ||||
} | } | ||||
@@ -10,7 +10,13 @@ import "fmt" | |||||
func WriteMap(w Writer, data map[string]interface{}, op string) error { | func WriteMap(w Writer, data map[string]interface{}, op string) error { | ||||
var args = make([]interface{}, 0, len(data)) | var args = make([]interface{}, 0, len(data)) | ||||
var i = 0 | var i = 0 | ||||
for k, v := range data { | |||||
keys := make([]string, 0, len(data)) | |||||
for k := range data { | |||||
keys = append(keys, k) | |||||
} | |||||
for _, k := range keys { | |||||
v := data[k] | |||||
switch v.(type) { | switch v.(type) { | ||||
case expr: | case expr: | ||||
if _, err := fmt.Fprintf(w, "%s%s(", k, op); err != nil { | if _, err := fmt.Fprintf(w, "%s%s(", k, op); err != nil { | ||||
@@ -4,7 +4,10 @@ | |||||
package builder | package builder | ||||
import "fmt" | |||||
import ( | |||||
"fmt" | |||||
"sort" | |||||
) | |||||
// Incr implements a type used by Eq | // Incr implements a type used by Eq | ||||
type Incr int | type Incr int | ||||
@@ -19,7 +22,8 @@ var _ Cond = Eq{} | |||||
func (eq Eq) opWriteTo(op string, w Writer) error { | func (eq Eq) opWriteTo(op string, w Writer) error { | ||||
var i = 0 | var i = 0 | ||||
for k, v := range eq { | |||||
for _, k := range eq.sortedKeys() { | |||||
v := eq[k] | |||||
switch v.(type) { | switch v.(type) { | ||||
case []int, []int64, []string, []int32, []int16, []int8, []uint, []uint64, []uint32, []uint16, []interface{}: | case []int, []int64, []string, []int32, []int16, []int8, []uint, []uint64, []uint32, []uint16, []interface{}: | ||||
if err := In(k, v).WriteTo(w); err != nil { | if err := In(k, v).WriteTo(w); err != nil { | ||||
@@ -94,3 +98,15 @@ func (eq Eq) Or(conds ...Cond) Cond { | |||||
func (eq Eq) IsValid() bool { | func (eq Eq) IsValid() bool { | ||||
return len(eq) > 0 | return len(eq) > 0 | ||||
} | } | ||||
// sortedKeys returns all keys of this Eq sorted with sort.Strings. | |||||
// It is used internally for consistent ordering when generating | |||||
// SQL, see https://github.com/go-xorm/builder/issues/10 | |||||
func (eq Eq) sortedKeys() []string { | |||||
keys := make([]string, 0, len(eq)) | |||||
for key := range eq { | |||||
keys = append(keys, key) | |||||
} | |||||
sort.Strings(keys) | |||||
return keys | |||||
} |
@@ -16,7 +16,7 @@ func (like Like) WriteTo(w Writer) error { | |||||
if _, err := fmt.Fprintf(w, "%s LIKE ?", like[0]); err != nil { | if _, err := fmt.Fprintf(w, "%s LIKE ?", like[0]); err != nil { | ||||
return err | return err | ||||
} | } | ||||
// FIXME: if use other regular express, this will be failed. but for compitable, keep this | |||||
// FIXME: if use other regular express, this will be failed. but for compatible, keep this | |||||
if like[1][0] == '%' || like[1][len(like[1])-1] == '%' { | if like[1][0] == '%' || like[1][len(like[1])-1] == '%' { | ||||
w.Append(like[1]) | w.Append(like[1]) | ||||
} else { | } else { | ||||
@@ -4,7 +4,10 @@ | |||||
package builder | package builder | ||||
import "fmt" | |||||
import ( | |||||
"fmt" | |||||
"sort" | |||||
) | |||||
// Neq defines not equal conditions | // Neq defines not equal conditions | ||||
type Neq map[string]interface{} | type Neq map[string]interface{} | ||||
@@ -15,7 +18,8 @@ var _ Cond = Neq{} | |||||
func (neq Neq) WriteTo(w Writer) error { | func (neq Neq) WriteTo(w Writer) error { | ||||
var args = make([]interface{}, 0, len(neq)) | var args = make([]interface{}, 0, len(neq)) | ||||
var i = 0 | var i = 0 | ||||
for k, v := range neq { | |||||
for _, k := range neq.sortedKeys() { | |||||
v := neq[k] | |||||
switch v.(type) { | switch v.(type) { | ||||
case []int, []int64, []string, []int32, []int16, []int8: | case []int, []int64, []string, []int32, []int16, []int8: | ||||
if err := NotIn(k, v).WriteTo(w); err != nil { | if err := NotIn(k, v).WriteTo(w); err != nil { | ||||
@@ -76,3 +80,15 @@ func (neq Neq) Or(conds ...Cond) Cond { | |||||
func (neq Neq) IsValid() bool { | func (neq Neq) IsValid() bool { | ||||
return len(neq) > 0 | return len(neq) > 0 | ||||
} | } | ||||
// sortedKeys returns all keys of this Neq sorted with sort.Strings. | |||||
// It is used internally for consistent ordering when generating | |||||
// SQL, see https://github.com/go-xorm/builder/issues/10 | |||||
func (neq Neq) sortedKeys() []string { | |||||
keys := make([]string, 0, len(neq)) | |||||
for key := range neq { | |||||
keys = append(keys, key) | |||||
} | |||||
sort.Strings(keys) | |||||
return keys | |||||
} |
@@ -21,6 +21,18 @@ func (not Not) WriteTo(w Writer) error { | |||||
if _, err := fmt.Fprint(w, "("); err != nil { | if _, err := fmt.Fprint(w, "("); err != nil { | ||||
return err | return err | ||||
} | } | ||||
case Eq: | |||||
if len(not[0].(Eq)) > 1 { | |||||
if _, err := fmt.Fprint(w, "("); err != nil { | |||||
return err | |||||
} | |||||
} | |||||
case Neq: | |||||
if len(not[0].(Neq)) > 1 { | |||||
if _, err := fmt.Fprint(w, "("); err != nil { | |||||
return err | |||||
} | |||||
} | |||||
} | } | ||||
if err := not[0].WriteTo(w); err != nil { | if err := not[0].WriteTo(w); err != nil { | ||||
@@ -32,6 +44,18 @@ func (not Not) WriteTo(w Writer) error { | |||||
if _, err := fmt.Fprint(w, ")"); err != nil { | if _, err := fmt.Fprint(w, ")"); err != nil { | ||||
return err | return err | ||||
} | } | ||||
case Eq: | |||||
if len(not[0].(Eq)) > 1 { | |||||
if _, err := fmt.Fprint(w, ")"); err != nil { | |||||
return err | |||||
} | |||||
} | |||||
case Neq: | |||||
if len(not[0].(Neq)) > 1 { | |||||
if _, err := fmt.Fprint(w, ")"); err != nil { | |||||
return err | |||||
} | |||||
} | |||||
} | } | ||||
return nil | return nil | ||||
@@ -0,0 +1,119 @@ | |||||
// Copyright 2017 The Go Authors. All rights reserved. | |||||
// Use of this source code is governed by a BSD-style | |||||
// license that can be found in the LICENSE file. | |||||
package builder | |||||
import ( | |||||
"unicode/utf8" | |||||
"unsafe" | |||||
) | |||||
// A StringBuilder is used to efficiently build a string using Write methods. | |||||
// It minimizes memory copying. The zero value is ready to use. | |||||
// Do not copy a non-zero Builder. | |||||
type StringBuilder struct { | |||||
addr *StringBuilder // of receiver, to detect copies by value | |||||
buf []byte | |||||
} | |||||
// noescape hides a pointer from escape analysis. noescape is | |||||
// the identity function but escape analysis doesn't think the | |||||
// output depends on the input. noescape is inlined and currently | |||||
// compiles down to zero instructions. | |||||
// USE CAREFULLY! | |||||
// This was copied from the runtime; see issues 23382 and 7921. | |||||
//go:nosplit | |||||
func noescape(p unsafe.Pointer) unsafe.Pointer { | |||||
x := uintptr(p) | |||||
return unsafe.Pointer(x ^ 0) | |||||
} | |||||
func (b *StringBuilder) copyCheck() { | |||||
if b.addr == nil { | |||||
// This hack works around a failing of Go's escape analysis | |||||
// that was causing b to escape and be heap allocated. | |||||
// See issue 23382. | |||||
// TODO: once issue 7921 is fixed, this should be reverted to | |||||
// just "b.addr = b". | |||||
b.addr = (*StringBuilder)(noescape(unsafe.Pointer(b))) | |||||
} else if b.addr != b { | |||||
panic("strings: illegal use of non-zero Builder copied by value") | |||||
} | |||||
} | |||||
// String returns the accumulated string. | |||||
func (b *StringBuilder) String() string { | |||||
return *(*string)(unsafe.Pointer(&b.buf)) | |||||
} | |||||
// Len returns the number of accumulated bytes; b.Len() == len(b.String()). | |||||
func (b *StringBuilder) Len() int { return len(b.buf) } | |||||
// Reset resets the Builder to be empty. | |||||
func (b *StringBuilder) Reset() { | |||||
b.addr = nil | |||||
b.buf = nil | |||||
} | |||||
// grow copies the buffer to a new, larger buffer so that there are at least n | |||||
// bytes of capacity beyond len(b.buf). | |||||
func (b *StringBuilder) grow(n int) { | |||||
buf := make([]byte, len(b.buf), 2*cap(b.buf)+n) | |||||
copy(buf, b.buf) | |||||
b.buf = buf | |||||
} | |||||
// Grow grows b's capacity, if necessary, to guarantee space for | |||||
// another n bytes. After Grow(n), at least n bytes can be written to b | |||||
// without another allocation. If n is negative, Grow panics. | |||||
func (b *StringBuilder) Grow(n int) { | |||||
b.copyCheck() | |||||
if n < 0 { | |||||
panic("strings.Builder.Grow: negative count") | |||||
} | |||||
if cap(b.buf)-len(b.buf) < n { | |||||
b.grow(n) | |||||
} | |||||
} | |||||
// Write appends the contents of p to b's buffer. | |||||
// Write always returns len(p), nil. | |||||
func (b *StringBuilder) Write(p []byte) (int, error) { | |||||
b.copyCheck() | |||||
b.buf = append(b.buf, p...) | |||||
return len(p), nil | |||||
} | |||||
// WriteByte appends the byte c to b's buffer. | |||||
// The returned error is always nil. | |||||
func (b *StringBuilder) WriteByte(c byte) error { | |||||
b.copyCheck() | |||||
b.buf = append(b.buf, c) | |||||
return nil | |||||
} | |||||
// WriteRune appends the UTF-8 encoding of Unicode code point r to b's buffer. | |||||
// It returns the length of r and a nil error. | |||||
func (b *StringBuilder) WriteRune(r rune) (int, error) { | |||||
b.copyCheck() | |||||
if r < utf8.RuneSelf { | |||||
b.buf = append(b.buf, byte(r)) | |||||
return 1, nil | |||||
} | |||||
l := len(b.buf) | |||||
if cap(b.buf)-l < utf8.UTFMax { | |||||
b.grow(utf8.UTFMax) | |||||
} | |||||
n := utf8.EncodeRune(b.buf[l:l+utf8.UTFMax], r) | |||||
b.buf = b.buf[:l+n] | |||||
return n, nil | |||||
} | |||||
// WriteString appends the contents of s to b's buffer. | |||||
// It returns the length of s and a nil error. | |||||
func (b *StringBuilder) WriteString(s string) (int, error) { | |||||
b.copyCheck() | |||||
b.buf = append(b.buf, s...) | |||||
return len(s), nil | |||||
} |
@@ -147,12 +147,12 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { | |||||
} | } | ||||
fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1]) | fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1]) | ||||
} else { | } else { | ||||
return nil, fmt.Errorf("field %v is not valid", col.FieldName) | |||||
return nil, fmt.Errorf("field %v is not valid", col.FieldName) | |||||
} | } | ||||
} | } | ||||
if !fieldValue.IsValid() { | if !fieldValue.IsValid() { | ||||
return nil, fmt.Errorf("field %v is not valid", col.FieldName) | |||||
return nil, fmt.Errorf("field %v is not valid", col.FieldName) | |||||
} | } | ||||
return &fieldValue, nil | return &fieldValue, nil | ||||
@@ -7,6 +7,11 @@ import ( | |||||
"fmt" | "fmt" | ||||
"reflect" | "reflect" | ||||
"regexp" | "regexp" | ||||
"sync" | |||||
) | |||||
var ( | |||||
DefaultCacheSize = 200 | |||||
) | ) | ||||
func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { | func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { | ||||
@@ -58,9 +63,16 @@ func StructToSlice(query string, st interface{}) (string, []interface{}, error) | |||||
return query, args, nil | return query, args, nil | ||||
} | } | ||||
type cacheStruct struct { | |||||
value reflect.Value | |||||
idx int | |||||
} | |||||
type DB struct { | type DB struct { | ||||
*sql.DB | *sql.DB | ||||
Mapper IMapper | |||||
Mapper IMapper | |||||
reflectCache map[reflect.Type]*cacheStruct | |||||
reflectCacheMutex sync.RWMutex | |||||
} | } | ||||
func Open(driverName, dataSourceName string) (*DB, error) { | func Open(driverName, dataSourceName string) (*DB, error) { | ||||
@@ -68,11 +80,32 @@ func Open(driverName, dataSourceName string) (*DB, error) { | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
return &DB{db, NewCacheMapper(&SnakeMapper{})}, nil | |||||
return &DB{ | |||||
DB: db, | |||||
Mapper: NewCacheMapper(&SnakeMapper{}), | |||||
reflectCache: make(map[reflect.Type]*cacheStruct), | |||||
}, nil | |||||
} | } | ||||
func FromDB(db *sql.DB) *DB { | func FromDB(db *sql.DB) *DB { | ||||
return &DB{db, NewCacheMapper(&SnakeMapper{})} | |||||
return &DB{ | |||||
DB: db, | |||||
Mapper: NewCacheMapper(&SnakeMapper{}), | |||||
reflectCache: make(map[reflect.Type]*cacheStruct), | |||||
} | |||||
} | |||||
func (db *DB) reflectNew(typ reflect.Type) reflect.Value { | |||||
db.reflectCacheMutex.Lock() | |||||
defer db.reflectCacheMutex.Unlock() | |||||
cs, ok := db.reflectCache[typ] | |||||
if !ok || cs.idx+1 > DefaultCacheSize-1 { | |||||
cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0} | |||||
db.reflectCache[typ] = cs | |||||
} else { | |||||
cs.idx = cs.idx + 1 | |||||
} | |||||
return cs.value.Index(cs.idx).Addr() | |||||
} | } | ||||
func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { | func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { | ||||
@@ -83,7 +116,7 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { | |||||
} | } | ||||
return nil, err | return nil, err | ||||
} | } | ||||
return &Rows{rows, db.Mapper}, nil | |||||
return &Rows{rows, db}, nil | |||||
} | } | ||||
func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) { | func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) { | ||||
@@ -128,8 +161,8 @@ func (db *DB) QueryRowStruct(query string, st interface{}) *Row { | |||||
type Stmt struct { | type Stmt struct { | ||||
*sql.Stmt | *sql.Stmt | ||||
Mapper IMapper | |||||
names map[string]int | |||||
db *DB | |||||
names map[string]int | |||||
} | } | ||||
func (db *DB) Prepare(query string) (*Stmt, error) { | func (db *DB) Prepare(query string) (*Stmt, error) { | ||||
@@ -145,7 +178,7 @@ func (db *DB) Prepare(query string) (*Stmt, error) { | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
return &Stmt{stmt, db.Mapper, names}, nil | |||||
return &Stmt{stmt, db, names}, nil | |||||
} | } | ||||
func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) { | func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) { | ||||
@@ -179,7 +212,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
return &Rows{rows, s.Mapper}, nil | |||||
return &Rows{rows, s.db}, nil | |||||
} | } | ||||
func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) { | func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) { | ||||
@@ -274,7 +307,7 @@ func (EmptyScanner) Scan(src interface{}) error { | |||||
type Tx struct { | type Tx struct { | ||||
*sql.Tx | *sql.Tx | ||||
Mapper IMapper | |||||
db *DB | |||||
} | } | ||||
func (db *DB) Begin() (*Tx, error) { | func (db *DB) Begin() (*Tx, error) { | ||||
@@ -282,7 +315,7 @@ func (db *DB) Begin() (*Tx, error) { | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
return &Tx{tx, db.Mapper}, nil | |||||
return &Tx{tx, db}, nil | |||||
} | } | ||||
func (tx *Tx) Prepare(query string) (*Stmt, error) { | func (tx *Tx) Prepare(query string) (*Stmt, error) { | ||||
@@ -298,7 +331,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
return &Stmt{stmt, tx.Mapper, names}, nil | |||||
return &Stmt{stmt, tx.db, names}, nil | |||||
} | } | ||||
func (tx *Tx) Stmt(stmt *Stmt) *Stmt { | func (tx *Tx) Stmt(stmt *Stmt) *Stmt { | ||||
@@ -327,7 +360,7 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
return &Rows{rows, tx.Mapper}, nil | |||||
return &Rows{rows, tx.db}, nil | |||||
} | } | ||||
func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) { | func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) { | ||||
@@ -74,6 +74,7 @@ type Dialect interface { | |||||
GetIndexes(tableName string) (map[string]*Index, error) | GetIndexes(tableName string) (map[string]*Index, error) | ||||
Filters() []Filter | Filters() []Filter | ||||
SetParams(params map[string]string) | |||||
} | } | ||||
func OpenDialect(dialect Dialect) (*DB, error) { | func OpenDialect(dialect Dialect) (*DB, error) { | ||||
@@ -148,7 +149,8 @@ func (db *Base) SupportDropIfExists() bool { | |||||
} | } | ||||
func (db *Base) DropTableSql(tableName string) string { | func (db *Base) DropTableSql(tableName string) string { | ||||
return fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName) | |||||
quote := db.dialect.Quote | |||||
return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)) | |||||
} | } | ||||
func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) { | func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) { | ||||
@@ -289,6 +291,9 @@ func (b *Base) LogSQL(sql string, args []interface{}) { | |||||
} | } | ||||
} | } | ||||
func (b *Base) SetParams(params map[string]string) { | |||||
} | |||||
var ( | var ( | ||||
dialects = map[string]func() Dialect{} | dialects = map[string]func() Dialect{} | ||||
) | ) | ||||
@@ -37,9 +37,9 @@ func (q *Quoter) Quote(content string) string { | |||||
func (i *IdFilter) Do(sql string, dialect Dialect, table *Table) string { | func (i *IdFilter) Do(sql string, dialect Dialect, table *Table) string { | ||||
quoter := NewQuoter(dialect) | quoter := NewQuoter(dialect) | ||||
if table != nil && len(table.PrimaryKeys) == 1 { | if table != nil && len(table.PrimaryKeys) == 1 { | ||||
sql = strings.Replace(sql, "`(id)`", quoter.Quote(table.PrimaryKeys[0]), -1) | |||||
sql = strings.Replace(sql, quoter.Quote("(id)"), quoter.Quote(table.PrimaryKeys[0]), -1) | |||||
return strings.Replace(sql, "(id)", quoter.Quote(table.PrimaryKeys[0]), -1) | |||||
sql = strings.Replace(sql, " `(id)` ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) | |||||
sql = strings.Replace(sql, " "+quoter.Quote("(id)")+" ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) | |||||
return strings.Replace(sql, " (id) ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) | |||||
} | } | ||||
return sql | return sql | ||||
} | } | ||||
@@ -22,6 +22,8 @@ type Index struct { | |||||
func (index *Index) XName(tableName string) string { | func (index *Index) XName(tableName string) string { | ||||
if !strings.HasPrefix(index.Name, "UQE_") && | if !strings.HasPrefix(index.Name, "UQE_") && | ||||
!strings.HasPrefix(index.Name, "IDX_") { | !strings.HasPrefix(index.Name, "IDX_") { | ||||
tableName = strings.Replace(tableName, `"`, "", -1) | |||||
tableName = strings.Replace(tableName, `.`, "_", -1) | |||||
if index.Type == UniqueType { | if index.Type == UniqueType { | ||||
return fmt.Sprintf("UQE_%v_%v", tableName, index.Name) | return fmt.Sprintf("UQE_%v_%v", tableName, index.Name) | ||||
} | } | ||||
@@ -9,7 +9,7 @@ import ( | |||||
type Rows struct { | type Rows struct { | ||||
*sql.Rows | *sql.Rows | ||||
Mapper IMapper | |||||
db *DB | |||||
} | } | ||||
func (rs *Rows) ToMapString() ([]map[string]string, error) { | func (rs *Rows) ToMapString() ([]map[string]string, error) { | ||||
@@ -105,7 +105,7 @@ func (rs *Rows) ScanStructByName(dest interface{}) error { | |||||
newDest := make([]interface{}, len(cols)) | newDest := make([]interface{}, len(cols)) | ||||
var v EmptyScanner | var v EmptyScanner | ||||
for j, name := range cols { | for j, name := range cols { | ||||
f := fieldByName(vv.Elem(), rs.Mapper.Table2Obj(name)) | |||||
f := fieldByName(vv.Elem(), rs.db.Mapper.Table2Obj(name)) | |||||
if f.IsValid() { | if f.IsValid() { | ||||
newDest[j] = f.Addr().Interface() | newDest[j] = f.Addr().Interface() | ||||
} else { | } else { | ||||
@@ -116,36 +116,6 @@ func (rs *Rows) ScanStructByName(dest interface{}) error { | |||||
return rs.Rows.Scan(newDest...) | return rs.Rows.Scan(newDest...) | ||||
} | } | ||||
type cacheStruct struct { | |||||
value reflect.Value | |||||
idx int | |||||
} | |||||
var ( | |||||
reflectCache = make(map[reflect.Type]*cacheStruct) | |||||
reflectCacheMutex sync.RWMutex | |||||
) | |||||
func ReflectNew(typ reflect.Type) reflect.Value { | |||||
reflectCacheMutex.RLock() | |||||
cs, ok := reflectCache[typ] | |||||
reflectCacheMutex.RUnlock() | |||||
const newSize = 200 | |||||
if !ok || cs.idx+1 > newSize-1 { | |||||
cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), newSize, newSize), 0} | |||||
reflectCacheMutex.Lock() | |||||
reflectCache[typ] = cs | |||||
reflectCacheMutex.Unlock() | |||||
} else { | |||||
reflectCacheMutex.Lock() | |||||
cs.idx = cs.idx + 1 | |||||
reflectCacheMutex.Unlock() | |||||
} | |||||
return cs.value.Index(cs.idx).Addr() | |||||
} | |||||
// scan data to a slice's pointer, slice's length should equal to columns' number | // scan data to a slice's pointer, slice's length should equal to columns' number | ||||
func (rs *Rows) ScanSlice(dest interface{}) error { | func (rs *Rows) ScanSlice(dest interface{}) error { | ||||
vv := reflect.ValueOf(dest) | vv := reflect.ValueOf(dest) | ||||
@@ -197,9 +167,7 @@ func (rs *Rows) ScanMap(dest interface{}) error { | |||||
vvv := vv.Elem() | vvv := vv.Elem() | ||||
for i, _ := range cols { | for i, _ := range cols { | ||||
newDest[i] = ReflectNew(vvv.Type().Elem()).Interface() | |||||
//v := reflect.New(vvv.Type().Elem()) | |||||
//newDest[i] = v.Interface() | |||||
newDest[i] = rs.db.reflectNew(vvv.Type().Elem()).Interface() | |||||
} | } | ||||
err = rs.Rows.Scan(newDest...) | err = rs.Rows.Scan(newDest...) | ||||
@@ -215,32 +183,6 @@ func (rs *Rows) ScanMap(dest interface{}) error { | |||||
return nil | return nil | ||||
} | } | ||||
/*func (rs *Rows) ScanMap(dest interface{}) error { | |||||
vv := reflect.ValueOf(dest) | |||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { | |||||
return errors.New("dest should be a map's pointer") | |||||
} | |||||
cols, err := rs.Columns() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
newDest := make([]interface{}, len(cols)) | |||||
err = rs.ScanSlice(newDest) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
vvv := vv.Elem() | |||||
for i, name := range cols { | |||||
vname := reflect.ValueOf(name) | |||||
vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem()) | |||||
} | |||||
return nil | |||||
}*/ | |||||
type Row struct { | type Row struct { | ||||
rows *Rows | rows *Rows | ||||
// One of these two will be non-nil: | // One of these two will be non-nil: | ||||
@@ -49,7 +49,6 @@ func NewTable(name string, t reflect.Type) *Table { | |||||
} | } | ||||
func (table *Table) columnsByName(name string) []*Column { | func (table *Table) columnsByName(name string) []*Column { | ||||
n := len(name) | n := len(name) | ||||
for k := range table.columnsMap { | for k := range table.columnsMap { | ||||
@@ -75,7 +74,6 @@ func (table *Table) GetColumn(name string) *Column { | |||||
} | } | ||||
func (table *Table) GetColumnIdx(name string, idx int) *Column { | func (table *Table) GetColumnIdx(name string, idx int) *Column { | ||||
cols := table.columnsByName(name) | cols := table.columnsByName(name) | ||||
if cols != nil && idx < len(cols) { | if cols != nil && idx < len(cols) { | ||||
@@ -69,15 +69,18 @@ var ( | |||||
Enum = "ENUM" | Enum = "ENUM" | ||||
Set = "SET" | Set = "SET" | ||||
Char = "CHAR" | |||||
Varchar = "VARCHAR" | |||||
NVarchar = "NVARCHAR" | |||||
TinyText = "TINYTEXT" | |||||
Text = "TEXT" | |||||
Clob = "CLOB" | |||||
MediumText = "MEDIUMTEXT" | |||||
LongText = "LONGTEXT" | |||||
Uuid = "UUID" | |||||
Char = "CHAR" | |||||
Varchar = "VARCHAR" | |||||
NVarchar = "NVARCHAR" | |||||
TinyText = "TINYTEXT" | |||||
Text = "TEXT" | |||||
NText = "NTEXT" | |||||
Clob = "CLOB" | |||||
MediumText = "MEDIUMTEXT" | |||||
LongText = "LONGTEXT" | |||||
Uuid = "UUID" | |||||
UniqueIdentifier = "UNIQUEIDENTIFIER" | |||||
SysName = "SYSNAME" | |||||
Date = "DATE" | Date = "DATE" | ||||
DateTime = "DATETIME" | DateTime = "DATETIME" | ||||
@@ -128,10 +131,12 @@ var ( | |||||
NVarchar: TEXT_TYPE, | NVarchar: TEXT_TYPE, | ||||
TinyText: TEXT_TYPE, | TinyText: TEXT_TYPE, | ||||
Text: TEXT_TYPE, | Text: TEXT_TYPE, | ||||
NText: TEXT_TYPE, | |||||
MediumText: TEXT_TYPE, | MediumText: TEXT_TYPE, | ||||
LongText: TEXT_TYPE, | LongText: TEXT_TYPE, | ||||
Uuid: TEXT_TYPE, | Uuid: TEXT_TYPE, | ||||
Clob: TEXT_TYPE, | Clob: TEXT_TYPE, | ||||
SysName: TEXT_TYPE, | |||||
Date: TIME_TYPE, | Date: TIME_TYPE, | ||||
DateTime: TIME_TYPE, | DateTime: TIME_TYPE, | ||||
@@ -148,11 +153,12 @@ var ( | |||||
Binary: BLOB_TYPE, | Binary: BLOB_TYPE, | ||||
VarBinary: BLOB_TYPE, | VarBinary: BLOB_TYPE, | ||||
TinyBlob: BLOB_TYPE, | |||||
Blob: BLOB_TYPE, | |||||
MediumBlob: BLOB_TYPE, | |||||
LongBlob: BLOB_TYPE, | |||||
Bytea: BLOB_TYPE, | |||||
TinyBlob: BLOB_TYPE, | |||||
Blob: BLOB_TYPE, | |||||
MediumBlob: BLOB_TYPE, | |||||
LongBlob: BLOB_TYPE, | |||||
Bytea: BLOB_TYPE, | |||||
UniqueIdentifier: BLOB_TYPE, | |||||
Bool: NUMERIC_TYPE, | Bool: NUMERIC_TYPE, | ||||
@@ -289,9 +295,9 @@ func SQLType2Type(st SQLType) reflect.Type { | |||||
return reflect.TypeOf(float32(1)) | return reflect.TypeOf(float32(1)) | ||||
case Double: | case Double: | ||||
return reflect.TypeOf(float64(1)) | return reflect.TypeOf(float64(1)) | ||||
case Char, Varchar, NVarchar, TinyText, Text, MediumText, LongText, Enum, Set, Uuid, Clob: | |||||
case Char, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName: | |||||
return reflect.TypeOf("") | return reflect.TypeOf("") | ||||
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary: | |||||
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier: | |||||
return reflect.TypeOf([]byte{}) | return reflect.TypeOf([]byte{}) | ||||
case Bool: | case Bool: | ||||
return reflect.TypeOf(true) | return reflect.TypeOf(true) | ||||
@@ -172,12 +172,33 @@ type mysql struct { | |||||
allowAllFiles bool | allowAllFiles bool | ||||
allowOldPasswords bool | allowOldPasswords bool | ||||
clientFoundRows bool | clientFoundRows bool | ||||
rowFormat string | |||||
} | } | ||||
func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { | func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { | ||||
return db.Base.Init(d, db, uri, drivername, dataSourceName) | return db.Base.Init(d, db, uri, drivername, dataSourceName) | ||||
} | } | ||||
func (db *mysql) SetParams(params map[string]string) { | |||||
rowFormat, ok := params["rowFormat"] | |||||
if ok { | |||||
var t = strings.ToUpper(rowFormat) | |||||
switch t { | |||||
case "COMPACT": | |||||
fallthrough | |||||
case "REDUNDANT": | |||||
fallthrough | |||||
case "DYNAMIC": | |||||
fallthrough | |||||
case "COMPRESSED": | |||||
db.rowFormat = t | |||||
break | |||||
default: | |||||
break | |||||
} | |||||
} | |||||
} | |||||
func (db *mysql) SqlType(c *core.Column) string { | func (db *mysql) SqlType(c *core.Column) string { | ||||
var res string | var res string | ||||
switch t := c.SQLType.Name; t { | switch t := c.SQLType.Name; t { | ||||
@@ -487,6 +508,59 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { | |||||
return indexes, nil | return indexes, nil | ||||
} | } | ||||
func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { | |||||
var sql string | |||||
sql = "CREATE TABLE IF NOT EXISTS " | |||||
if tableName == "" { | |||||
tableName = table.Name | |||||
} | |||||
sql += db.Quote(tableName) | |||||
sql += " (" | |||||
if len(table.ColumnsSeq()) > 0 { | |||||
pkList := table.PrimaryKeys | |||||
for _, colName := range table.ColumnsSeq() { | |||||
col := table.GetColumn(colName) | |||||
if col.IsPrimaryKey && len(pkList) == 1 { | |||||
sql += col.String(db) | |||||
} else { | |||||
sql += col.StringNoPk(db) | |||||
} | |||||
sql = strings.TrimSpace(sql) | |||||
if len(col.Comment) > 0 { | |||||
sql += " COMMENT '" + col.Comment + "'" | |||||
} | |||||
sql += ", " | |||||
} | |||||
if len(pkList) > 1 { | |||||
sql += "PRIMARY KEY ( " | |||||
sql += db.Quote(strings.Join(pkList, db.Quote(","))) | |||||
sql += " ), " | |||||
} | |||||
sql = sql[:len(sql)-2] | |||||
} | |||||
sql += ")" | |||||
if storeEngine != "" { | |||||
sql += " ENGINE=" + storeEngine | |||||
} | |||||
if len(charset) == 0 { | |||||
charset = db.URI().Charset | |||||
} else if len(charset) > 0 { | |||||
sql += " DEFAULT CHARSET " + charset | |||||
} | |||||
if db.rowFormat != "" { | |||||
sql += " ROW_FORMAT=" + db.rowFormat | |||||
} | |||||
return sql | |||||
} | |||||
func (db *mysql) Filters() []core.Filter { | func (db *mysql) Filters() []core.Filter { | ||||
return []core.Filter{&core.IdFilter{}} | return []core.Filter{&core.IdFilter{}} | ||||
} | } | ||||
@@ -769,14 +769,21 @@ var ( | |||||
DefaultPostgresSchema = "public" | DefaultPostgresSchema = "public" | ||||
) | ) | ||||
const postgresPublicSchema = "public" | |||||
type postgres struct { | type postgres struct { | ||||
core.Base | core.Base | ||||
schema string | |||||
} | } | ||||
func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { | func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { | ||||
db.schema = DefaultPostgresSchema | |||||
return db.Base.Init(d, db, uri, drivername, dataSourceName) | |||||
err := db.Base.Init(d, db, uri, drivername, dataSourceName) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
if db.Schema == "" { | |||||
db.Schema = DefaultPostgresSchema | |||||
} | |||||
return nil | |||||
} | } | ||||
func (db *postgres) SqlType(c *core.Column) string { | func (db *postgres) SqlType(c *core.Column) string { | ||||
@@ -873,32 +880,42 @@ func (db *postgres) IndexOnTable() bool { | |||||
} | } | ||||
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { | func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { | ||||
args := []interface{}{tableName, idxName} | |||||
if len(db.Schema) == 0 { | |||||
args := []interface{}{tableName, idxName} | |||||
return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args | |||||
} | |||||
args := []interface{}{db.Schema, tableName, idxName} | |||||
return `SELECT indexname FROM pg_indexes ` + | return `SELECT indexname FROM pg_indexes ` + | ||||
`WHERE tablename = ? AND indexname = ?`, args | |||||
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args | |||||
} | } | ||||
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { | func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { | ||||
args := []interface{}{tableName} | |||||
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args | |||||
} | |||||
if len(db.Schema) == 0 { | |||||
args := []interface{}{tableName} | |||||
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args | |||||
} | |||||
/*func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) { | |||||
args := []interface{}{tableName, colName} | |||||
return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + | |||||
" AND column_name = ?", args | |||||
}*/ | |||||
args := []interface{}{db.Schema, tableName} | |||||
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args | |||||
} | |||||
func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { | func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { | ||||
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", | |||||
tableName, col.Name, db.SqlType(col)) | |||||
if len(db.Schema) == 0 { | |||||
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", | |||||
tableName, col.Name, db.SqlType(col)) | |||||
} | |||||
return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", | |||||
db.Schema, tableName, col.Name, db.SqlType(col)) | |||||
} | } | ||||
func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { | func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { | ||||
//var unique string | |||||
quote := db.Quote | quote := db.Quote | ||||
idxName := index.Name | idxName := index.Name | ||||
tableName = strings.Replace(tableName, `"`, "", -1) | |||||
tableName = strings.Replace(tableName, `.`, "_", -1) | |||||
if !strings.HasPrefix(idxName, "UQE_") && | if !strings.HasPrefix(idxName, "UQE_") && | ||||
!strings.HasPrefix(idxName, "IDX_") { | !strings.HasPrefix(idxName, "IDX_") { | ||||
if index.Type == core.UniqueType { | if index.Type == core.UniqueType { | ||||
@@ -907,13 +924,21 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { | |||||
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) | idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) | ||||
} | } | ||||
} | } | ||||
if db.Uri.Schema != "" { | |||||
idxName = db.Uri.Schema + "." + idxName | |||||
} | |||||
return fmt.Sprintf("DROP INDEX %v", quote(idxName)) | return fmt.Sprintf("DROP INDEX %v", quote(idxName)) | ||||
} | } | ||||
func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { | func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { | ||||
args := []interface{}{tableName, colName} | |||||
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + | |||||
" AND column_name = $2" | |||||
args := []interface{}{db.Schema, tableName, colName} | |||||
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + | |||||
" AND column_name = $3" | |||||
if len(db.Schema) == 0 { | |||||
args = []interface{}{tableName, colName} | |||||
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + | |||||
" AND column_name = $2" | |||||
} | |||||
db.LogSQL(query, args) | db.LogSQL(query, args) | ||||
rows, err := db.DB().Query(query, args...) | rows, err := db.DB().Query(query, args...) | ||||
@@ -926,8 +951,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { | |||||
} | } | ||||
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { | func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { | ||||
// FIXME: the schema should be replaced by user custom's | |||||
args := []interface{}{tableName, db.schema} | |||||
args := []interface{}{tableName} | |||||
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix , | s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix , | ||||
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, | CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, | ||||
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey | CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey | ||||
@@ -938,7 +962,15 @@ FROM pg_attribute f | |||||
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) | LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) | ||||
LEFT JOIN pg_class AS g ON p.confrelid = g.oid | LEFT JOIN pg_class AS g ON p.confrelid = g.oid | ||||
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name | LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name | ||||
WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;` | |||||
WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` | |||||
var f string | |||||
if len(db.Schema) != 0 { | |||||
args = append(args, db.Schema) | |||||
f = " AND s.table_schema = $2" | |||||
} | |||||
s = fmt.Sprintf(s, f) | |||||
db.LogSQL(s, args) | db.LogSQL(s, args) | ||||
rows, err := db.DB().Query(s, args...) | rows, err := db.DB().Query(s, args...) | ||||
@@ -1028,8 +1060,13 @@ WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.att | |||||
} | } | ||||
func (db *postgres) GetTables() ([]*core.Table, error) { | func (db *postgres) GetTables() ([]*core.Table, error) { | ||||
args := []interface{}{db.schema} | |||||
s := fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = $1") | |||||
args := []interface{}{} | |||||
s := "SELECT tablename FROM pg_tables" | |||||
if len(db.Schema) != 0 { | |||||
args = append(args, db.Schema) | |||||
s = s + " WHERE schemaname = $1" | |||||
} | |||||
db.LogSQL(s, args) | db.LogSQL(s, args) | ||||
rows, err := db.DB().Query(s, args...) | rows, err := db.DB().Query(s, args...) | ||||
@@ -1053,8 +1090,12 @@ func (db *postgres) GetTables() ([]*core.Table, error) { | |||||
} | } | ||||
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { | func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { | ||||
args := []interface{}{db.schema, tableName} | |||||
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname=$1 AND tablename=$2") | |||||
args := []interface{}{tableName} | |||||
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") | |||||
if len(db.Schema) != 0 { | |||||
args = append(args, db.Schema) | |||||
s = s + " AND schemaname=$2" | |||||
} | |||||
db.LogSQL(s, args) | db.LogSQL(s, args) | ||||
rows, err := db.DB().Query(s, args...) | rows, err := db.DB().Query(s, args...) | ||||
@@ -1182,3 +1223,15 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { | |||||
return db, nil | return db, nil | ||||
} | } | ||||
type pqDriverPgx struct { | |||||
pqDriver | |||||
} | |||||
func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*core.Uri, error) { | |||||
// Remove the leading characters for driver to work | |||||
if len(dataSourceName) >= 9 && dataSourceName[0] == 0 { | |||||
dataSourceName = dataSourceName[9:] | |||||
} | |||||
return pgx.pqDriver.Parse(driverName, dataSourceName) | |||||
} |
@@ -49,6 +49,35 @@ type Engine struct { | |||||
tagHandlers map[string]tagHandler | tagHandlers map[string]tagHandler | ||||
engineGroup *EngineGroup | engineGroup *EngineGroup | ||||
cachers map[string]core.Cacher | |||||
cacherLock sync.RWMutex | |||||
} | |||||
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { | |||||
engine.cacherLock.Lock() | |||||
engine.cachers[tableName] = cacher | |||||
engine.cacherLock.Unlock() | |||||
} | |||||
func (engine *Engine) SetCacher(tableName string, cacher core.Cacher) { | |||||
engine.setCacher(tableName, cacher) | |||||
} | |||||
func (engine *Engine) getCacher(tableName string) core.Cacher { | |||||
var cacher core.Cacher | |||||
var ok bool | |||||
engine.cacherLock.RLock() | |||||
cacher, ok = engine.cachers[tableName] | |||||
engine.cacherLock.RUnlock() | |||||
if !ok && !engine.disableGlobalCache { | |||||
cacher = engine.Cacher | |||||
} | |||||
return cacher | |||||
} | |||||
func (engine *Engine) GetCacher(tableName string) core.Cacher { | |||||
return engine.getCacher(tableName) | |||||
} | } | ||||
// BufferSize sets buffer size for iterate | // BufferSize sets buffer size for iterate | ||||
@@ -165,7 +194,7 @@ func (engine *Engine) Quote(value string) string { | |||||
} | } | ||||
// QuoteTo quotes string and writes into the buffer | // QuoteTo quotes string and writes into the buffer | ||||
func (engine *Engine) QuoteTo(buf *bytes.Buffer, value string) { | |||||
func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) { | |||||
if buf == nil { | if buf == nil { | ||||
return | return | ||||
} | } | ||||
@@ -245,13 +274,7 @@ func (engine *Engine) NoCascade() *Session { | |||||
// MapCacher Set a table use a special cacher | // MapCacher Set a table use a special cacher | ||||
func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error { | func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error { | ||||
v := rValue(bean) | |||||
tb, err := engine.autoMapType(v) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
tb.Cacher = cacher | |||||
engine.setCacher(engine.TableName(bean, true), cacher) | |||||
return nil | return nil | ||||
} | } | ||||
@@ -536,33 +559,6 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D | |||||
return nil | return nil | ||||
} | } | ||||
func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) { | |||||
v := rValue(beanOrTableName) | |||||
if v.Type().Kind() == reflect.String { | |||||
return beanOrTableName.(string), nil | |||||
} else if v.Type().Kind() == reflect.Struct { | |||||
return engine.tbName(v), nil | |||||
} | |||||
return "", errors.New("bean should be a struct or struct's point") | |||||
} | |||||
func (engine *Engine) tbName(v reflect.Value) string { | |||||
if tb, ok := v.Interface().(TableName); ok { | |||||
return tb.TableName() | |||||
} | |||||
if v.Type().Kind() == reflect.Ptr { | |||||
if tb, ok := reflect.Indirect(v).Interface().(TableName); ok { | |||||
return tb.TableName() | |||||
} | |||||
} else if v.CanAddr() { | |||||
if tb, ok := v.Addr().Interface().(TableName); ok { | |||||
return tb.TableName() | |||||
} | |||||
} | |||||
return engine.TableMapper.Obj2Table(reflect.Indirect(v).Type().Name()) | |||||
} | |||||
// Cascade use cascade or not | // Cascade use cascade or not | ||||
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { | func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { | ||||
session := engine.NewSession() | session := engine.NewSession() | ||||
@@ -846,7 +842,7 @@ func (engine *Engine) TableInfo(bean interface{}) *Table { | |||||
if err != nil { | if err != nil { | ||||
engine.logger.Error(err) | engine.logger.Error(err) | ||||
} | } | ||||
return &Table{tb, engine.tbName(v)} | |||||
return &Table{tb, engine.TableName(bean)} | |||||
} | } | ||||
func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) { | func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) { | ||||
@@ -861,15 +857,6 @@ func addIndex(indexName string, table *core.Table, col *core.Column, indexType i | |||||
} | } | ||||
} | } | ||||
func (engine *Engine) newTable() *core.Table { | |||||
table := core.NewEmptyTable() | |||||
if !engine.disableGlobalCache { | |||||
table.Cacher = engine.Cacher | |||||
} | |||||
return table | |||||
} | |||||
// TableName table name interface to define customerize table name | // TableName table name interface to define customerize table name | ||||
type TableName interface { | type TableName interface { | ||||
TableName() string | TableName() string | ||||
@@ -881,21 +868,9 @@ var ( | |||||
func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { | func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { | ||||
t := v.Type() | t := v.Type() | ||||
table := engine.newTable() | |||||
if tb, ok := v.Interface().(TableName); ok { | |||||
table.Name = tb.TableName() | |||||
} else { | |||||
if v.CanAddr() { | |||||
if tb, ok = v.Addr().Interface().(TableName); ok { | |||||
table.Name = tb.TableName() | |||||
} | |||||
} | |||||
if table.Name == "" { | |||||
table.Name = engine.TableMapper.Obj2Table(t.Name()) | |||||
} | |||||
} | |||||
table := core.NewEmptyTable() | |||||
table.Type = t | table.Type = t | ||||
table.Name = engine.tbNameForMap(v) | |||||
var idFieldColName string | var idFieldColName string | ||||
var hasCacheTag, hasNoCacheTag bool | var hasCacheTag, hasNoCacheTag bool | ||||
@@ -1049,15 +1024,15 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { | |||||
if hasCacheTag { | if hasCacheTag { | ||||
if engine.Cacher != nil { // !nash! use engine's cacher if provided | if engine.Cacher != nil { // !nash! use engine's cacher if provided | ||||
engine.logger.Info("enable cache on table:", table.Name) | engine.logger.Info("enable cache on table:", table.Name) | ||||
table.Cacher = engine.Cacher | |||||
engine.setCacher(table.Name, engine.Cacher) | |||||
} else { | } else { | ||||
engine.logger.Info("enable LRU cache on table:", table.Name) | engine.logger.Info("enable LRU cache on table:", table.Name) | ||||
table.Cacher = NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) // !nashtsai! HACK use LRU cacher for now | |||||
engine.setCacher(table.Name, NewLRUCacher2(NewMemoryStore(), time.Hour, 10000)) | |||||
} | } | ||||
} | } | ||||
if hasNoCacheTag { | if hasNoCacheTag { | ||||
engine.logger.Info("no cache on table:", table.Name) | |||||
table.Cacher = nil | |||||
engine.logger.Info("disable cache on table:", table.Name) | |||||
engine.setCacher(table.Name, nil) | |||||
} | } | ||||
return table, nil | return table, nil | ||||
@@ -1116,7 +1091,25 @@ func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) { | |||||
pk := make([]interface{}, len(table.PrimaryKeys)) | pk := make([]interface{}, len(table.PrimaryKeys)) | ||||
for i, col := range table.PKColumns() { | for i, col := range table.PKColumns() { | ||||
var err error | var err error | ||||
pkField := v.FieldByName(col.FieldName) | |||||
fieldName := col.FieldName | |||||
for { | |||||
parts := strings.SplitN(fieldName, ".", 2) | |||||
if len(parts) == 1 { | |||||
break | |||||
} | |||||
v = v.FieldByName(parts[0]) | |||||
if v.Kind() == reflect.Ptr { | |||||
v = v.Elem() | |||||
} | |||||
if v.Kind() != reflect.Struct { | |||||
return nil, ErrUnSupportedType | |||||
} | |||||
fieldName = parts[1] | |||||
} | |||||
pkField := v.FieldByName(fieldName) | |||||
switch pkField.Kind() { | switch pkField.Kind() { | ||||
case reflect.String: | case reflect.String: | ||||
pk[i], err = engine.idTypeAssertion(col, pkField.String()) | pk[i], err = engine.idTypeAssertion(col, pkField.String()) | ||||
@@ -1162,26 +1155,10 @@ func (engine *Engine) CreateUniques(bean interface{}) error { | |||||
return session.CreateUniques(bean) | return session.CreateUniques(bean) | ||||
} | } | ||||
func (engine *Engine) getCacher2(table *core.Table) core.Cacher { | |||||
return table.Cacher | |||||
} | |||||
// ClearCacheBean if enabled cache, clear the cache bean | // ClearCacheBean if enabled cache, clear the cache bean | ||||
func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { | func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { | ||||
v := rValue(bean) | |||||
t := v.Type() | |||||
if t.Kind() != reflect.Struct { | |||||
return errors.New("error params") | |||||
} | |||||
tableName := engine.tbName(v) | |||||
table, err := engine.autoMapType(v) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
cacher := table.Cacher | |||||
if cacher == nil { | |||||
cacher = engine.Cacher | |||||
} | |||||
tableName := engine.TableName(bean) | |||||
cacher := engine.getCacher(tableName) | |||||
if cacher != nil { | if cacher != nil { | ||||
cacher.ClearIds(tableName) | cacher.ClearIds(tableName) | ||||
cacher.DelBean(tableName, id) | cacher.DelBean(tableName, id) | ||||
@@ -1192,21 +1169,8 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { | |||||
// ClearCache if enabled cache, clear some tables' cache | // ClearCache if enabled cache, clear some tables' cache | ||||
func (engine *Engine) ClearCache(beans ...interface{}) error { | func (engine *Engine) ClearCache(beans ...interface{}) error { | ||||
for _, bean := range beans { | for _, bean := range beans { | ||||
v := rValue(bean) | |||||
t := v.Type() | |||||
if t.Kind() != reflect.Struct { | |||||
return errors.New("error params") | |||||
} | |||||
tableName := engine.tbName(v) | |||||
table, err := engine.autoMapType(v) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
cacher := table.Cacher | |||||
if cacher == nil { | |||||
cacher = engine.Cacher | |||||
} | |||||
tableName := engine.TableName(bean) | |||||
cacher := engine.getCacher(tableName) | |||||
if cacher != nil { | if cacher != nil { | ||||
cacher.ClearIds(tableName) | cacher.ClearIds(tableName) | ||||
cacher.ClearBeans(tableName) | cacher.ClearBeans(tableName) | ||||
@@ -1224,13 +1188,13 @@ func (engine *Engine) Sync(beans ...interface{}) error { | |||||
for _, bean := range beans { | for _, bean := range beans { | ||||
v := rValue(bean) | v := rValue(bean) | ||||
tableName := engine.tbName(v) | |||||
tableNameNoSchema := engine.TableName(bean) | |||||
table, err := engine.autoMapType(v) | table, err := engine.autoMapType(v) | ||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
} | } | ||||
isExist, err := session.Table(bean).isTableExist(tableName) | |||||
isExist, err := session.Table(bean).isTableExist(tableNameNoSchema) | |||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
} | } | ||||
@@ -1256,12 +1220,12 @@ func (engine *Engine) Sync(beans ...interface{}) error { | |||||
} | } | ||||
} else { | } else { | ||||
for _, col := range table.Columns() { | for _, col := range table.Columns() { | ||||
isExist, err := engine.dialect.IsColumnExist(tableName, col.Name) | |||||
isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name) | |||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
} | } | ||||
if !isExist { | if !isExist { | ||||
if err := session.statement.setRefValue(v); err != nil { | |||||
if err := session.statement.setRefBean(bean); err != nil { | |||||
return err | return err | ||||
} | } | ||||
err = session.addColumn(col.Name) | err = session.addColumn(col.Name) | ||||
@@ -1272,35 +1236,35 @@ func (engine *Engine) Sync(beans ...interface{}) error { | |||||
} | } | ||||
for name, index := range table.Indexes { | for name, index := range table.Indexes { | ||||
if err := session.statement.setRefValue(v); err != nil { | |||||
if err := session.statement.setRefBean(bean); err != nil { | |||||
return err | return err | ||||
} | } | ||||
if index.Type == core.UniqueType { | if index.Type == core.UniqueType { | ||||
isExist, err := session.isIndexExist2(tableName, index.Cols, true) | |||||
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true) | |||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
} | } | ||||
if !isExist { | if !isExist { | ||||
if err := session.statement.setRefValue(v); err != nil { | |||||
if err := session.statement.setRefBean(bean); err != nil { | |||||
return err | return err | ||||
} | } | ||||
err = session.addUnique(tableName, name) | |||||
err = session.addUnique(tableNameNoSchema, name) | |||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
} | } | ||||
} | } | ||||
} else if index.Type == core.IndexType { | } else if index.Type == core.IndexType { | ||||
isExist, err := session.isIndexExist2(tableName, index.Cols, false) | |||||
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false) | |||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
} | } | ||||
if !isExist { | if !isExist { | ||||
if err := session.statement.setRefValue(v); err != nil { | |||||
if err := session.statement.setRefBean(bean); err != nil { | |||||
return err | return err | ||||
} | } | ||||
err = session.addIndex(tableName, name) | |||||
err = session.addIndex(tableNameNoSchema, name) | |||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
} | } | ||||
@@ -1453,6 +1417,13 @@ func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error { | |||||
return session.Find(beans, condiBeans...) | return session.Find(beans, condiBeans...) | ||||
} | } | ||||
// FindAndCount find the results and also return the counts | |||||
func (engine *Engine) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) { | |||||
session := engine.NewSession() | |||||
defer session.Close() | |||||
return session.FindAndCount(rowsSlicePtr, condiBean...) | |||||
} | |||||
// Iterate record by record handle records from table, bean's non-empty fields | // Iterate record by record handle records from table, bean's non-empty fields | ||||
// are conditions. | // are conditions. | ||||
func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error { | func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error { | ||||
@@ -1629,6 +1600,11 @@ func (engine *Engine) SetTZDatabase(tz *time.Location) { | |||||
engine.DatabaseTZ = tz | engine.DatabaseTZ = tz | ||||
} | } | ||||
// SetSchema sets the schema of database | |||||
func (engine *Engine) SetSchema(schema string) { | |||||
engine.dialect.URI().Schema = schema | |||||
} | |||||
// Unscoped always disable struct tag "deleted" | // Unscoped always disable struct tag "deleted" | ||||
func (engine *Engine) Unscoped() *Session { | func (engine *Engine) Unscoped() *Session { | ||||
session := engine.NewSession() | session := engine.NewSession() | ||||
@@ -9,6 +9,7 @@ import ( | |||||
"encoding/json" | "encoding/json" | ||||
"fmt" | "fmt" | ||||
"reflect" | "reflect" | ||||
"strings" | |||||
"time" | "time" | ||||
"github.com/go-xorm/builder" | "github.com/go-xorm/builder" | ||||
@@ -51,7 +52,9 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{}, | |||||
fieldValuePtr, err := col.ValueOf(bean) | fieldValuePtr, err := col.ValueOf(bean) | ||||
if err != nil { | if err != nil { | ||||
engine.logger.Error(err) | |||||
if !strings.Contains(err.Error(), "is not valid") { | |||||
engine.logger.Warn(err) | |||||
} | |||||
continue | continue | ||||
} | } | ||||
@@ -0,0 +1,113 @@ | |||||
// Copyright 2018 The Xorm Authors. All rights reserved. | |||||
// Use of this source code is governed by a BSD-style | |||||
// license that can be found in the LICENSE file. | |||||
package xorm | |||||
import ( | |||||
"fmt" | |||||
"reflect" | |||||
"strings" | |||||
"github.com/go-xorm/core" | |||||
) | |||||
// TableNameWithSchema will automatically add schema prefix on table name | |||||
func (engine *Engine) tbNameWithSchema(v string) string { | |||||
// Add schema name as prefix of table name. | |||||
// Only for postgres database. | |||||
if engine.dialect.DBType() == core.POSTGRES && | |||||
engine.dialect.URI().Schema != "" && | |||||
engine.dialect.URI().Schema != postgresPublicSchema && | |||||
strings.Index(v, ".") == -1 { | |||||
return engine.dialect.URI().Schema + "." + v | |||||
} | |||||
return v | |||||
} | |||||
// TableName returns table name with schema prefix if has | |||||
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { | |||||
tbName := engine.tbNameNoSchema(bean) | |||||
if len(includeSchema) > 0 && includeSchema[0] { | |||||
tbName = engine.tbNameWithSchema(tbName) | |||||
} | |||||
return tbName | |||||
} | |||||
// tbName get some table's table name | |||||
func (session *Session) tbNameNoSchema(table *core.Table) string { | |||||
if len(session.statement.AltTableName) > 0 { | |||||
return session.statement.AltTableName | |||||
} | |||||
return table.Name | |||||
} | |||||
func (engine *Engine) tbNameForMap(v reflect.Value) string { | |||||
if v.Type().Implements(tpTableName) { | |||||
return v.Interface().(TableName).TableName() | |||||
} | |||||
if v.Kind() == reflect.Ptr { | |||||
v = v.Elem() | |||||
if v.Type().Implements(tpTableName) { | |||||
return v.Interface().(TableName).TableName() | |||||
} | |||||
} | |||||
return engine.TableMapper.Obj2Table(v.Type().Name()) | |||||
} | |||||
func (engine *Engine) tbNameNoSchema(tablename interface{}) string { | |||||
switch tablename.(type) { | |||||
case []string: | |||||
t := tablename.([]string) | |||||
if len(t) > 1 { | |||||
return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) | |||||
} else if len(t) == 1 { | |||||
return engine.Quote(t[0]) | |||||
} | |||||
case []interface{}: | |||||
t := tablename.([]interface{}) | |||||
l := len(t) | |||||
var table string | |||||
if l > 0 { | |||||
f := t[0] | |||||
switch f.(type) { | |||||
case string: | |||||
table = f.(string) | |||||
case TableName: | |||||
table = f.(TableName).TableName() | |||||
default: | |||||
v := rValue(f) | |||||
t := v.Type() | |||||
if t.Kind() == reflect.Struct { | |||||
table = engine.tbNameForMap(v) | |||||
} else { | |||||
table = engine.Quote(fmt.Sprintf("%v", f)) | |||||
} | |||||
} | |||||
} | |||||
if l > 1 { | |||||
return fmt.Sprintf("%v AS %v", engine.Quote(table), | |||||
engine.Quote(fmt.Sprintf("%v", t[1]))) | |||||
} else if l == 1 { | |||||
return engine.Quote(table) | |||||
} | |||||
case TableName: | |||||
return tablename.(TableName).TableName() | |||||
case string: | |||||
return tablename.(string) | |||||
case reflect.Value: | |||||
v := tablename.(reflect.Value) | |||||
return engine.tbNameForMap(v) | |||||
default: | |||||
v := rValue(tablename) | |||||
t := v.Type() | |||||
if t.Kind() == reflect.Struct { | |||||
return engine.tbNameForMap(v) | |||||
} | |||||
return engine.Quote(fmt.Sprintf("%v", tablename)) | |||||
} | |||||
return "" | |||||
} |
@@ -6,23 +6,44 @@ package xorm | |||||
import ( | import ( | ||||
"errors" | "errors" | ||||
"fmt" | |||||
) | ) | ||||
var ( | var ( | ||||
// ErrParamsType params error | // ErrParamsType params error | ||||
ErrParamsType = errors.New("Params type error") | ErrParamsType = errors.New("Params type error") | ||||
// ErrTableNotFound table not found error | // ErrTableNotFound table not found error | ||||
ErrTableNotFound = errors.New("Not found table") | |||||
ErrTableNotFound = errors.New("Table not found") | |||||
// ErrUnSupportedType unsupported error | // ErrUnSupportedType unsupported error | ||||
ErrUnSupportedType = errors.New("Unsupported type error") | ErrUnSupportedType = errors.New("Unsupported type error") | ||||
// ErrNotExist record is not exist error | |||||
ErrNotExist = errors.New("Not exist error") | |||||
// ErrNotExist record does not exist error | |||||
ErrNotExist = errors.New("Record does not exist") | |||||
// ErrCacheFailed cache failed error | // ErrCacheFailed cache failed error | ||||
ErrCacheFailed = errors.New("Cache failed") | ErrCacheFailed = errors.New("Cache failed") | ||||
// ErrNeedDeletedCond delete needs less one condition error | // ErrNeedDeletedCond delete needs less one condition error | ||||
ErrNeedDeletedCond = errors.New("Delete need at least one condition") | |||||
ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") | |||||
// ErrNotImplemented not implemented | // ErrNotImplemented not implemented | ||||
ErrNotImplemented = errors.New("Not implemented") | ErrNotImplemented = errors.New("Not implemented") | ||||
// ErrConditionType condition type unsupported | // ErrConditionType condition type unsupported | ||||
ErrConditionType = errors.New("Unsupported conditon type") | |||||
ErrConditionType = errors.New("Unsupported condition type") | |||||
) | ) | ||||
// ErrFieldIsNotExist columns does not exist | |||||
type ErrFieldIsNotExist struct { | |||||
FieldName string | |||||
TableName string | |||||
} | |||||
func (e ErrFieldIsNotExist) Error() string { | |||||
return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) | |||||
} | |||||
// ErrFieldIsNotValid is not valid | |||||
type ErrFieldIsNotValid struct { | |||||
FieldName string | |||||
TableName string | |||||
} | |||||
func (e ErrFieldIsNotValid) Error() string { | |||||
return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) | |||||
} |
@@ -11,7 +11,6 @@ import ( | |||||
"sort" | "sort" | ||||
"strconv" | "strconv" | ||||
"strings" | "strings" | ||||
"time" | |||||
"github.com/go-xorm/core" | "github.com/go-xorm/core" | ||||
) | ) | ||||
@@ -293,19 +292,6 @@ func structName(v reflect.Type) string { | |||||
return v.Name() | return v.Name() | ||||
} | } | ||||
func col2NewCols(columns ...string) []string { | |||||
newColumns := make([]string, 0, len(columns)) | |||||
for _, col := range columns { | |||||
col = strings.Replace(col, "`", "", -1) | |||||
col = strings.Replace(col, `"`, "", -1) | |||||
ccols := strings.Split(col, ",") | |||||
for _, c := range ccols { | |||||
newColumns = append(newColumns, strings.TrimSpace(c)) | |||||
} | |||||
} | |||||
return newColumns | |||||
} | |||||
func sliceEq(left, right []string) bool { | func sliceEq(left, right []string) bool { | ||||
if len(left) != len(right) { | if len(left) != len(right) { | ||||
return false | return false | ||||
@@ -320,154 +306,6 @@ func sliceEq(left, right []string) bool { | |||||
return true | return true | ||||
} | } | ||||
func setColumnInt(bean interface{}, col *core.Column, t int64) { | |||||
v, err := col.ValueOf(bean) | |||||
if err != nil { | |||||
return | |||||
} | |||||
if v.CanSet() { | |||||
switch v.Type().Kind() { | |||||
case reflect.Int, reflect.Int64, reflect.Int32: | |||||
v.SetInt(t) | |||||
case reflect.Uint, reflect.Uint64, reflect.Uint32: | |||||
v.SetUint(uint64(t)) | |||||
} | |||||
} | |||||
} | |||||
func setColumnTime(bean interface{}, col *core.Column, t time.Time) { | |||||
v, err := col.ValueOf(bean) | |||||
if err != nil { | |||||
return | |||||
} | |||||
if v.CanSet() { | |||||
switch v.Type().Kind() { | |||||
case reflect.Struct: | |||||
v.Set(reflect.ValueOf(t).Convert(v.Type())) | |||||
case reflect.Int, reflect.Int64, reflect.Int32: | |||||
v.SetInt(t.Unix()) | |||||
case reflect.Uint, reflect.Uint64, reflect.Uint32: | |||||
v.SetUint(uint64(t.Unix())) | |||||
} | |||||
} | |||||
} | |||||
func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) { | |||||
colNames := make([]string, 0, len(table.ColumnsSeq())) | |||||
args := make([]interface{}, 0, len(table.ColumnsSeq())) | |||||
for _, col := range table.Columns() { | |||||
if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { | |||||
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { | |||||
continue | |||||
} | |||||
} | |||||
if col.MapType == core.ONLYFROMDB { | |||||
continue | |||||
} | |||||
fieldValuePtr, err := col.ValueOf(bean) | |||||
if err != nil { | |||||
return nil, nil, err | |||||
} | |||||
fieldValue := *fieldValuePtr | |||||
if col.IsAutoIncrement { | |||||
switch fieldValue.Type().Kind() { | |||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: | |||||
if fieldValue.Int() == 0 { | |||||
continue | |||||
} | |||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: | |||||
if fieldValue.Uint() == 0 { | |||||
continue | |||||
} | |||||
case reflect.String: | |||||
if len(fieldValue.String()) == 0 { | |||||
continue | |||||
} | |||||
case reflect.Ptr: | |||||
if fieldValue.Pointer() == 0 { | |||||
continue | |||||
} | |||||
} | |||||
} | |||||
if col.IsDeleted { | |||||
continue | |||||
} | |||||
if session.statement.ColumnStr != "" { | |||||
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { | |||||
continue | |||||
} else if _, ok := session.statement.incrColumns[col.Name]; ok { | |||||
continue | |||||
} else if _, ok := session.statement.decrColumns[col.Name]; ok { | |||||
continue | |||||
} | |||||
} | |||||
if session.statement.OmitStr != "" { | |||||
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { | |||||
continue | |||||
} | |||||
} | |||||
// !evalphobia! set fieldValue as nil when column is nullable and zero-value | |||||
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { | |||||
if col.Nullable && isZero(fieldValue.Interface()) { | |||||
var nilValue *int | |||||
fieldValue = reflect.ValueOf(nilValue) | |||||
} | |||||
} | |||||
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { | |||||
// if time is non-empty, then set to auto time | |||||
val, t := session.engine.nowTime(col) | |||||
args = append(args, val) | |||||
var colName = col.Name | |||||
session.afterClosures = append(session.afterClosures, func(bean interface{}) { | |||||
col := table.GetColumn(colName) | |||||
setColumnTime(bean, col, t) | |||||
}) | |||||
} else if col.IsVersion && session.statement.checkVersion { | |||||
args = append(args, 1) | |||||
} else { | |||||
arg, err := session.value2Interface(col, fieldValue) | |||||
if err != nil { | |||||
return colNames, args, err | |||||
} | |||||
args = append(args, arg) | |||||
} | |||||
if includeQuote { | |||||
colNames = append(colNames, session.engine.Quote(col.Name)+" = ?") | |||||
} else { | |||||
colNames = append(colNames, col.Name) | |||||
} | |||||
} | |||||
return colNames, args, nil | |||||
} | |||||
func indexName(tableName, idxName string) string { | func indexName(tableName, idxName string) string { | ||||
return fmt.Sprintf("IDX_%v_%v", tableName, idxName) | return fmt.Sprintf("IDX_%v_%v", tableName, idxName) | ||||
} | } | ||||
func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) { | |||||
if len(m) == 0 { | |||||
return false, false | |||||
} | |||||
n := len(col.Name) | |||||
for mk := range m { | |||||
if len(mk) != n { | |||||
continue | |||||
} | |||||
if strings.EqualFold(mk, col.Name) { | |||||
return m[mk], true | |||||
} | |||||
} | |||||
return false, false | |||||
} |
@@ -30,6 +30,7 @@ type Interface interface { | |||||
Exec(string, ...interface{}) (sql.Result, error) | Exec(string, ...interface{}) (sql.Result, error) | ||||
Exist(bean ...interface{}) (bool, error) | Exist(bean ...interface{}) (bool, error) | ||||
Find(interface{}, ...interface{}) error | Find(interface{}, ...interface{}) error | ||||
FindAndCount(interface{}, ...interface{}) (int64, error) | |||||
Get(interface{}) (bool, error) | Get(interface{}) (bool, error) | ||||
GroupBy(keys string) *Session | GroupBy(keys string) *Session | ||||
ID(interface{}) *Session | ID(interface{}) *Session | ||||
@@ -41,6 +42,7 @@ type Interface interface { | |||||
IsTableExist(beanOrTableName interface{}) (bool, error) | IsTableExist(beanOrTableName interface{}) (bool, error) | ||||
Iterate(interface{}, IterFunc) error | Iterate(interface{}, IterFunc) error | ||||
Limit(int, ...int) *Session | Limit(int, ...int) *Session | ||||
MustCols(columns ...string) *Session | |||||
NoAutoCondition(...bool) *Session | NoAutoCondition(...bool) *Session | ||||
NotIn(string, ...interface{}) *Session | NotIn(string, ...interface{}) *Session | ||||
Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session | Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session | ||||
@@ -75,6 +77,7 @@ type EngineInterface interface { | |||||
Dialect() core.Dialect | Dialect() core.Dialect | ||||
DropTables(...interface{}) error | DropTables(...interface{}) error | ||||
DumpAllToFile(fp string, tp ...core.DbType) error | DumpAllToFile(fp string, tp ...core.DbType) error | ||||
GetCacher(string) core.Cacher | |||||
GetColumnMapper() core.IMapper | GetColumnMapper() core.IMapper | ||||
GetDefaultCacher() core.Cacher | GetDefaultCacher() core.Cacher | ||||
GetTableMapper() core.IMapper | GetTableMapper() core.IMapper | ||||
@@ -83,9 +86,11 @@ type EngineInterface interface { | |||||
NewSession() *Session | NewSession() *Session | ||||
NoAutoTime() *Session | NoAutoTime() *Session | ||||
Quote(string) string | Quote(string) string | ||||
SetCacher(string, core.Cacher) | |||||
SetDefaultCacher(core.Cacher) | SetDefaultCacher(core.Cacher) | ||||
SetLogLevel(core.LogLevel) | SetLogLevel(core.LogLevel) | ||||
SetMapper(core.IMapper) | SetMapper(core.IMapper) | ||||
SetSchema(string) | |||||
SetTZDatabase(tz *time.Location) | SetTZDatabase(tz *time.Location) | ||||
SetTZLocation(tz *time.Location) | SetTZLocation(tz *time.Location) | ||||
ShowSQL(show ...bool) | ShowSQL(show ...bool) | ||||
@@ -93,6 +98,7 @@ type EngineInterface interface { | |||||
Sync2(...interface{}) error | Sync2(...interface{}) error | ||||
StoreEngine(storeEngine string) *Session | StoreEngine(storeEngine string) *Session | ||||
TableInfo(bean interface{}) *Table | TableInfo(bean interface{}) *Table | ||||
TableName(interface{}, ...bool) string | |||||
UnMapType(reflect.Type) | UnMapType(reflect.Type) | ||||
} | } | ||||
@@ -32,7 +32,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { | |||||
var args []interface{} | var args []interface{} | ||||
var err error | var err error | ||||
if err = rows.session.statement.setRefValue(rValue(bean)); err != nil { | |||||
if err = rows.session.statement.setRefBean(bean); err != nil { | |||||
return nil, err | return nil, err | ||||
} | } | ||||
@@ -94,8 +94,7 @@ func (rows *Rows) Scan(bean interface{}) error { | |||||
return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) | return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) | ||||
} | } | ||||
dataStruct := rValue(bean) | |||||
if err := rows.session.statement.setRefValue(dataStruct); err != nil { | |||||
if err := rows.session.statement.setRefBean(bean); err != nil { | |||||
return err | return err | ||||
} | } | ||||
@@ -104,6 +103,7 @@ func (rows *Rows) Scan(bean interface{}) error { | |||||
return err | return err | ||||
} | } | ||||
dataStruct := rValue(bean) | |||||
_, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable) | _, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable) | ||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
@@ -278,24 +278,22 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, | |||||
return | return | ||||
} | } | ||||
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value { | |||||
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) (*reflect.Value, error) { | |||||
var col *core.Column | var col *core.Column | ||||
if col = table.GetColumnIdx(key, idx); col == nil { | if col = table.GetColumnIdx(key, idx); col == nil { | ||||
//session.engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq()) | |||||
return nil | |||||
return nil, ErrFieldIsNotExist{key, table.Name} | |||||
} | } | ||||
fieldValue, err := col.ValueOfV(dataStruct) | fieldValue, err := col.ValueOfV(dataStruct) | ||||
if err != nil { | if err != nil { | ||||
session.engine.logger.Error(err) | |||||
return nil | |||||
return nil, err | |||||
} | } | ||||
if !fieldValue.IsValid() || !fieldValue.CanSet() { | if !fieldValue.IsValid() || !fieldValue.CanSet() { | ||||
session.engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key) | |||||
return nil | |||||
return nil, ErrFieldIsNotValid{key, table.Name} | |||||
} | } | ||||
return fieldValue | |||||
return fieldValue, nil | |||||
} | } | ||||
// Cell cell is a result of one column field | // Cell cell is a result of one column field | ||||
@@ -407,409 +405,417 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b | |||||
} | } | ||||
tempMap[lKey] = idx | tempMap[lKey] = idx | ||||
if fieldValue := session.getField(dataStruct, key, table, idx); fieldValue != nil { | |||||
rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) | |||||
// if row is null then ignore | |||||
if rawValue.Interface() == nil { | |||||
continue | |||||
fieldValue, err := session.getField(dataStruct, key, table, idx) | |||||
if err != nil { | |||||
if !strings.Contains(err.Error(), "is not valid") { | |||||
session.engine.logger.Warn(err) | |||||
} | } | ||||
continue | |||||
} | |||||
if fieldValue == nil { | |||||
continue | |||||
} | |||||
rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) | |||||
if fieldValue.CanAddr() { | |||||
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { | |||||
if data, err := value2Bytes(&rawValue); err == nil { | |||||
if err := structConvert.FromDB(data); err != nil { | |||||
return nil, err | |||||
} | |||||
} else { | |||||
return nil, err | |||||
} | |||||
continue | |||||
} | |||||
} | |||||
// if row is null then ignore | |||||
if rawValue.Interface() == nil { | |||||
continue | |||||
} | |||||
if _, ok := fieldValue.Interface().(core.Conversion); ok { | |||||
if fieldValue.CanAddr() { | |||||
if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { | |||||
if data, err := value2Bytes(&rawValue); err == nil { | if data, err := value2Bytes(&rawValue); err == nil { | ||||
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { | |||||
fieldValue.Set(reflect.New(fieldValue.Type().Elem())) | |||||
if err := structConvert.FromDB(data); err != nil { | |||||
return nil, err | |||||
} | } | ||||
fieldValue.Interface().(core.Conversion).FromDB(data) | |||||
} else { | } else { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
continue | continue | ||||
} | } | ||||
} | |||||
rawValueType := reflect.TypeOf(rawValue.Interface()) | |||||
vv := reflect.ValueOf(rawValue.Interface()) | |||||
col := table.GetColumnIdx(key, idx) | |||||
if col.IsPrimaryKey { | |||||
pk = append(pk, rawValue.Interface()) | |||||
if _, ok := fieldValue.Interface().(core.Conversion); ok { | |||||
if data, err := value2Bytes(&rawValue); err == nil { | |||||
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { | |||||
fieldValue.Set(reflect.New(fieldValue.Type().Elem())) | |||||
} | |||||
fieldValue.Interface().(core.Conversion).FromDB(data) | |||||
} else { | |||||
return nil, err | |||||
} | } | ||||
fieldType := fieldValue.Type() | |||||
hasAssigned := false | |||||
continue | |||||
} | |||||
if col.SQLType.IsJson() { | |||||
var bs []byte | |||||
if rawValueType.Kind() == reflect.String { | |||||
bs = []byte(vv.String()) | |||||
} else if rawValueType.ConvertibleTo(core.BytesType) { | |||||
bs = vv.Bytes() | |||||
} else { | |||||
return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) | |||||
} | |||||
rawValueType := reflect.TypeOf(rawValue.Interface()) | |||||
vv := reflect.ValueOf(rawValue.Interface()) | |||||
col := table.GetColumnIdx(key, idx) | |||||
if col.IsPrimaryKey { | |||||
pk = append(pk, rawValue.Interface()) | |||||
} | |||||
fieldType := fieldValue.Type() | |||||
hasAssigned := false | |||||
if col.SQLType.IsJson() { | |||||
var bs []byte | |||||
if rawValueType.Kind() == reflect.String { | |||||
bs = []byte(vv.String()) | |||||
} else if rawValueType.ConvertibleTo(core.BytesType) { | |||||
bs = vv.Bytes() | |||||
} else { | |||||
return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) | |||||
} | |||||
hasAssigned = true | |||||
hasAssigned = true | |||||
if len(bs) > 0 { | |||||
if fieldType.Kind() == reflect.String { | |||||
fieldValue.SetString(string(bs)) | |||||
continue | |||||
if len(bs) > 0 { | |||||
if fieldType.Kind() == reflect.String { | |||||
fieldValue.SetString(string(bs)) | |||||
continue | |||||
} | |||||
if fieldValue.CanAddr() { | |||||
err := json.Unmarshal(bs, fieldValue.Addr().Interface()) | |||||
if err != nil { | |||||
return nil, err | |||||
} | } | ||||
if fieldValue.CanAddr() { | |||||
err := json.Unmarshal(bs, fieldValue.Addr().Interface()) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
} else { | |||||
x := reflect.New(fieldType) | |||||
err := json.Unmarshal(bs, x.Interface()) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
fieldValue.Set(x.Elem()) | |||||
} else { | |||||
x := reflect.New(fieldType) | |||||
err := json.Unmarshal(bs, x.Interface()) | |||||
if err != nil { | |||||
return nil, err | |||||
} | } | ||||
fieldValue.Set(x.Elem()) | |||||
} | } | ||||
continue | |||||
} | } | ||||
switch fieldType.Kind() { | |||||
case reflect.Complex64, reflect.Complex128: | |||||
// TODO: reimplement this | |||||
var bs []byte | |||||
if rawValueType.Kind() == reflect.String { | |||||
bs = []byte(vv.String()) | |||||
} else if rawValueType.ConvertibleTo(core.BytesType) { | |||||
bs = vv.Bytes() | |||||
} | |||||
continue | |||||
} | |||||
hasAssigned = true | |||||
if len(bs) > 0 { | |||||
if fieldValue.CanAddr() { | |||||
err := json.Unmarshal(bs, fieldValue.Addr().Interface()) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
} else { | |||||
x := reflect.New(fieldType) | |||||
err := json.Unmarshal(bs, x.Interface()) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
fieldValue.Set(x.Elem()) | |||||
switch fieldType.Kind() { | |||||
case reflect.Complex64, reflect.Complex128: | |||||
// TODO: reimplement this | |||||
var bs []byte | |||||
if rawValueType.Kind() == reflect.String { | |||||
bs = []byte(vv.String()) | |||||
} else if rawValueType.ConvertibleTo(core.BytesType) { | |||||
bs = vv.Bytes() | |||||
} | |||||
hasAssigned = true | |||||
if len(bs) > 0 { | |||||
if fieldValue.CanAddr() { | |||||
err := json.Unmarshal(bs, fieldValue.Addr().Interface()) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
} else { | |||||
x := reflect.New(fieldType) | |||||
err := json.Unmarshal(bs, x.Interface()) | |||||
if err != nil { | |||||
return nil, err | |||||
} | } | ||||
fieldValue.Set(x.Elem()) | |||||
} | } | ||||
} | |||||
case reflect.Slice, reflect.Array: | |||||
switch rawValueType.Kind() { | |||||
case reflect.Slice, reflect.Array: | case reflect.Slice, reflect.Array: | ||||
switch rawValueType.Kind() { | |||||
case reflect.Slice, reflect.Array: | |||||
switch rawValueType.Elem().Kind() { | |||||
case reflect.Uint8: | |||||
if fieldType.Elem().Kind() == reflect.Uint8 { | |||||
hasAssigned = true | |||||
if col.SQLType.IsText() { | |||||
x := reflect.New(fieldType) | |||||
err := json.Unmarshal(vv.Bytes(), x.Interface()) | |||||
if err != nil { | |||||
return nil, err | |||||
switch rawValueType.Elem().Kind() { | |||||
case reflect.Uint8: | |||||
if fieldType.Elem().Kind() == reflect.Uint8 { | |||||
hasAssigned = true | |||||
if col.SQLType.IsText() { | |||||
x := reflect.New(fieldType) | |||||
err := json.Unmarshal(vv.Bytes(), x.Interface()) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
fieldValue.Set(x.Elem()) | |||||
} else { | |||||
if fieldValue.Len() > 0 { | |||||
for i := 0; i < fieldValue.Len(); i++ { | |||||
if i < vv.Len() { | |||||
fieldValue.Index(i).Set(vv.Index(i)) | |||||
} | |||||
} | } | ||||
fieldValue.Set(x.Elem()) | |||||
} else { | } else { | ||||
if fieldValue.Len() > 0 { | |||||
for i := 0; i < fieldValue.Len(); i++ { | |||||
if i < vv.Len() { | |||||
fieldValue.Index(i).Set(vv.Index(i)) | |||||
} | |||||
} | |||||
} else { | |||||
for i := 0; i < vv.Len(); i++ { | |||||
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) | |||||
} | |||||
for i := 0; i < vv.Len(); i++ { | |||||
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) | |||||
} | } | ||||
} | } | ||||
} | } | ||||
} | } | ||||
} | } | ||||
case reflect.String: | |||||
if rawValueType.Kind() == reflect.String { | |||||
hasAssigned = true | |||||
fieldValue.SetString(vv.String()) | |||||
} | |||||
case reflect.Bool: | |||||
if rawValueType.Kind() == reflect.Bool { | |||||
hasAssigned = true | |||||
fieldValue.SetBool(vv.Bool()) | |||||
} | |||||
} | |||||
case reflect.String: | |||||
if rawValueType.Kind() == reflect.String { | |||||
hasAssigned = true | |||||
fieldValue.SetString(vv.String()) | |||||
} | |||||
case reflect.Bool: | |||||
if rawValueType.Kind() == reflect.Bool { | |||||
hasAssigned = true | |||||
fieldValue.SetBool(vv.Bool()) | |||||
} | |||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||||
switch rawValueType.Kind() { | |||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
switch rawValueType.Kind() { | |||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||||
hasAssigned = true | |||||
fieldValue.SetInt(vv.Int()) | |||||
} | |||||
hasAssigned = true | |||||
fieldValue.SetInt(vv.Int()) | |||||
} | |||||
case reflect.Float32, reflect.Float64: | |||||
switch rawValueType.Kind() { | |||||
case reflect.Float32, reflect.Float64: | case reflect.Float32, reflect.Float64: | ||||
switch rawValueType.Kind() { | |||||
case reflect.Float32, reflect.Float64: | |||||
hasAssigned = true | |||||
fieldValue.SetFloat(vv.Float()) | |||||
} | |||||
hasAssigned = true | |||||
fieldValue.SetFloat(vv.Float()) | |||||
} | |||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: | |||||
switch rawValueType.Kind() { | |||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: | case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: | ||||
switch rawValueType.Kind() { | |||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: | |||||
hasAssigned = true | |||||
fieldValue.SetUint(vv.Uint()) | |||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||||
hasAssigned = true | |||||
fieldValue.SetUint(uint64(vv.Int())) | |||||
hasAssigned = true | |||||
fieldValue.SetUint(vv.Uint()) | |||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||||
hasAssigned = true | |||||
fieldValue.SetUint(uint64(vv.Int())) | |||||
} | |||||
case reflect.Struct: | |||||
if fieldType.ConvertibleTo(core.TimeType) { | |||||
dbTZ := session.engine.DatabaseTZ | |||||
if col.TimeZone != nil { | |||||
dbTZ = col.TimeZone | |||||
} | } | ||||
case reflect.Struct: | |||||
if fieldType.ConvertibleTo(core.TimeType) { | |||||
dbTZ := session.engine.DatabaseTZ | |||||
if col.TimeZone != nil { | |||||
dbTZ = col.TimeZone | |||||
} | |||||
if rawValueType == core.TimeType { | |||||
hasAssigned = true | |||||
t := vv.Convert(core.TimeType).Interface().(time.Time) | |||||
z, _ := t.Zone() | |||||
// set new location if database don't save timezone or give an incorrect timezone | |||||
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location | |||||
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) | |||||
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), | |||||
t.Minute(), t.Second(), t.Nanosecond(), dbTZ) | |||||
} | |||||
if rawValueType == core.TimeType { | |||||
hasAssigned = true | |||||
t = t.In(session.engine.TZLocation) | |||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||||
} else if rawValueType == core.IntType || rawValueType == core.Int64Type || | |||||
rawValueType == core.Int32Type { | |||||
hasAssigned = true | |||||
t := vv.Convert(core.TimeType).Interface().(time.Time) | |||||
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) | |||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||||
} else { | |||||
if d, ok := vv.Interface().([]uint8); ok { | |||||
hasAssigned = true | |||||
t, err := session.byte2Time(col, d) | |||||
if err != nil { | |||||
session.engine.logger.Error("byte2Time error:", err.Error()) | |||||
hasAssigned = false | |||||
} else { | |||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||||
} | |||||
} else if d, ok := vv.Interface().(string); ok { | |||||
hasAssigned = true | |||||
t, err := session.str2Time(col, d) | |||||
if err != nil { | |||||
session.engine.logger.Error("byte2Time error:", err.Error()) | |||||
hasAssigned = false | |||||
} else { | |||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||||
} | |||||
} else { | |||||
return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) | |||||
} | |||||
} | |||||
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { | |||||
// !<winxxp>! 增加支持sql.Scanner接口的结构,如sql.NullString | |||||
hasAssigned = true | |||||
if err := nulVal.Scan(vv.Interface()); err != nil { | |||||
session.engine.logger.Error("sql.Sanner error:", err.Error()) | |||||
hasAssigned = false | |||||
} | |||||
} else if col.SQLType.IsJson() { | |||||
if rawValueType.Kind() == reflect.String { | |||||
hasAssigned = true | |||||
x := reflect.New(fieldType) | |||||
if len([]byte(vv.String())) > 0 { | |||||
err := json.Unmarshal([]byte(vv.String()), x.Interface()) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
fieldValue.Set(x.Elem()) | |||||
} | |||||
} else if rawValueType.Kind() == reflect.Slice { | |||||
hasAssigned = true | |||||
x := reflect.New(fieldType) | |||||
if len(vv.Bytes()) > 0 { | |||||
err := json.Unmarshal(vv.Bytes(), x.Interface()) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
fieldValue.Set(x.Elem()) | |||||
} | |||||
} | |||||
} else if session.statement.UseCascade { | |||||
table, err := session.engine.autoMapType(*fieldValue) | |||||
if err != nil { | |||||
return nil, err | |||||
z, _ := t.Zone() | |||||
// set new location if database don't save timezone or give an incorrect timezone | |||||
if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location | |||||
session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) | |||||
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), | |||||
t.Minute(), t.Second(), t.Nanosecond(), dbTZ) | |||||
} | } | ||||
t = t.In(session.engine.TZLocation) | |||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||||
} else if rawValueType == core.IntType || rawValueType == core.Int64Type || | |||||
rawValueType == core.Int32Type { | |||||
hasAssigned = true | hasAssigned = true | ||||
if len(table.PrimaryKeys) != 1 { | |||||
return nil, errors.New("unsupported non or composited primary key cascade") | |||||
} | |||||
var pk = make(core.PK, len(table.PrimaryKeys)) | |||||
pk[0], err = asKind(vv, rawValueType) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if !isPKZero(pk) { | |||||
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch | |||||
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne | |||||
// property to be fetched lazily | |||||
structInter := reflect.New(fieldValue.Type()) | |||||
has, err := session.ID(pk).NoCascade().get(structInter.Interface()) | |||||
t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) | |||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||||
} else { | |||||
if d, ok := vv.Interface().([]uint8); ok { | |||||
hasAssigned = true | |||||
t, err := session.byte2Time(col, d) | |||||
if err != nil { | if err != nil { | ||||
return nil, err | |||||
session.engine.logger.Error("byte2Time error:", err.Error()) | |||||
hasAssigned = false | |||||
} else { | |||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||||
} | } | ||||
if has { | |||||
fieldValue.Set(structInter.Elem()) | |||||
} else if d, ok := vv.Interface().(string); ok { | |||||
hasAssigned = true | |||||
t, err := session.str2Time(col, d) | |||||
if err != nil { | |||||
session.engine.logger.Error("byte2Time error:", err.Error()) | |||||
hasAssigned = false | |||||
} else { | } else { | ||||
return nil, errors.New("cascade obj is not exist") | |||||
fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) | |||||
} | } | ||||
} else { | |||||
return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) | |||||
} | } | ||||
} | } | ||||
case reflect.Ptr: | |||||
// !nashtsai! TODO merge duplicated codes above | |||||
switch fieldType { | |||||
// following types case matching ptr's native type, therefore assign ptr directly | |||||
case core.PtrStringType: | |||||
if rawValueType.Kind() == reflect.String { | |||||
x := vv.String() | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrBoolType: | |||||
if rawValueType.Kind() == reflect.Bool { | |||||
x := vv.Bool() | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrTimeType: | |||||
if rawValueType == core.PtrTimeType { | |||||
hasAssigned = true | |||||
var x = rawValue.Interface().(time.Time) | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrFloat64Type: | |||||
if rawValueType.Kind() == reflect.Float64 { | |||||
x := vv.Float() | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrUint64Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = uint64(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrInt64Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
x := vv.Int() | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrFloat32Type: | |||||
if rawValueType.Kind() == reflect.Float64 { | |||||
var x = float32(vv.Float()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrIntType: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = int(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrInt32Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = int32(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrInt8Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = int8(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrInt16Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = int16(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrUintType: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = uint(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrUint32Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = uint32(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.Uint8Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = uint8(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.Uint16Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = uint16(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.Complex64Type: | |||||
var x complex64 | |||||
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { | |||||
// !<winxxp>! 增加支持sql.Scanner接口的结构,如sql.NullString | |||||
hasAssigned = true | |||||
if err := nulVal.Scan(vv.Interface()); err != nil { | |||||
session.engine.logger.Error("sql.Sanner error:", err.Error()) | |||||
hasAssigned = false | |||||
} | |||||
} else if col.SQLType.IsJson() { | |||||
if rawValueType.Kind() == reflect.String { | |||||
hasAssigned = true | |||||
x := reflect.New(fieldType) | |||||
if len([]byte(vv.String())) > 0 { | if len([]byte(vv.String())) > 0 { | ||||
err := json.Unmarshal([]byte(vv.String()), &x) | |||||
err := json.Unmarshal([]byte(vv.String()), x.Interface()) | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
fieldValue.Set(x.Elem()) | |||||
} | } | ||||
} else if rawValueType.Kind() == reflect.Slice { | |||||
hasAssigned = true | hasAssigned = true | ||||
case core.Complex128Type: | |||||
var x complex128 | |||||
if len([]byte(vv.String())) > 0 { | |||||
err := json.Unmarshal([]byte(vv.String()), &x) | |||||
x := reflect.New(fieldType) | |||||
if len(vv.Bytes()) > 0 { | |||||
err := json.Unmarshal(vv.Bytes(), x.Interface()) | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
fieldValue.Set(x.Elem()) | |||||
} | } | ||||
hasAssigned = true | |||||
} // switch fieldType | |||||
} // switch fieldType.Kind() | |||||
// !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value | |||||
if !hasAssigned { | |||||
data, err := value2Bytes(&rawValue) | |||||
} | |||||
} else if session.statement.UseCascade { | |||||
table, err := session.engine.autoMapType(*fieldValue) | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
if err = session.bytes2Value(col, fieldValue, data); err != nil { | |||||
hasAssigned = true | |||||
if len(table.PrimaryKeys) != 1 { | |||||
return nil, errors.New("unsupported non or composited primary key cascade") | |||||
} | |||||
var pk = make(core.PK, len(table.PrimaryKeys)) | |||||
pk[0], err = asKind(vv, rawValueType) | |||||
if err != nil { | |||||
return nil, err | return nil, err | ||||
} | } | ||||
if !isPKZero(pk) { | |||||
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch | |||||
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne | |||||
// property to be fetched lazily | |||||
structInter := reflect.New(fieldValue.Type()) | |||||
has, err := session.ID(pk).NoCascade().get(structInter.Interface()) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if has { | |||||
fieldValue.Set(structInter.Elem()) | |||||
} else { | |||||
return nil, errors.New("cascade obj is not exist") | |||||
} | |||||
} | |||||
} | |||||
case reflect.Ptr: | |||||
// !nashtsai! TODO merge duplicated codes above | |||||
switch fieldType { | |||||
// following types case matching ptr's native type, therefore assign ptr directly | |||||
case core.PtrStringType: | |||||
if rawValueType.Kind() == reflect.String { | |||||
x := vv.String() | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrBoolType: | |||||
if rawValueType.Kind() == reflect.Bool { | |||||
x := vv.Bool() | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrTimeType: | |||||
if rawValueType == core.PtrTimeType { | |||||
hasAssigned = true | |||||
var x = rawValue.Interface().(time.Time) | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrFloat64Type: | |||||
if rawValueType.Kind() == reflect.Float64 { | |||||
x := vv.Float() | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrUint64Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = uint64(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrInt64Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
x := vv.Int() | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrFloat32Type: | |||||
if rawValueType.Kind() == reflect.Float64 { | |||||
var x = float32(vv.Float()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrIntType: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = int(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrInt32Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = int32(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrInt8Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = int8(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrInt16Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = int16(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrUintType: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = uint(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.PtrUint32Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = uint32(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.Uint8Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = uint8(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.Uint16Type: | |||||
if rawValueType.Kind() == reflect.Int64 { | |||||
var x = uint16(vv.Int()) | |||||
hasAssigned = true | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
case core.Complex64Type: | |||||
var x complex64 | |||||
if len([]byte(vv.String())) > 0 { | |||||
err := json.Unmarshal([]byte(vv.String()), &x) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
hasAssigned = true | |||||
case core.Complex128Type: | |||||
var x complex128 | |||||
if len([]byte(vv.String())) > 0 { | |||||
err := json.Unmarshal([]byte(vv.String()), &x) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
fieldValue.Set(reflect.ValueOf(&x)) | |||||
} | |||||
hasAssigned = true | |||||
} // switch fieldType | |||||
} // switch fieldType.Kind() | |||||
// !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value | |||||
if !hasAssigned { | |||||
data, err := value2Bytes(&rawValue) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if err = session.bytes2Value(col, fieldValue, data); err != nil { | |||||
return nil, err | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -828,15 +834,6 @@ func (session *Session) LastSQL() (string, []interface{}) { | |||||
return session.lastSQL, session.lastSQLArgs | return session.lastSQL, session.lastSQLArgs | ||||
} | } | ||||
// tbName get some table's table name | |||||
func (session *Session) tbNameNoSchema(table *core.Table) string { | |||||
if len(session.statement.AltTableName) > 0 { | |||||
return session.statement.AltTableName | |||||
} | |||||
return table.Name | |||||
} | |||||
// Unscoped always disable struct tag "deleted" | // Unscoped always disable struct tag "deleted" | ||||
func (session *Session) Unscoped() *Session { | func (session *Session) Unscoped() *Session { | ||||
session.statement.Unscoped() | session.statement.Unscoped() | ||||
@@ -4,6 +4,121 @@ | |||||
package xorm | package xorm | ||||
import ( | |||||
"reflect" | |||||
"strings" | |||||
"time" | |||||
"github.com/go-xorm/core" | |||||
) | |||||
type incrParam struct { | |||||
colName string | |||||
arg interface{} | |||||
} | |||||
type decrParam struct { | |||||
colName string | |||||
arg interface{} | |||||
} | |||||
type exprParam struct { | |||||
colName string | |||||
expr string | |||||
} | |||||
type columnMap []string | |||||
func (m columnMap) contain(colName string) bool { | |||||
if len(m) == 0 { | |||||
return false | |||||
} | |||||
n := len(colName) | |||||
for _, mk := range m { | |||||
if len(mk) != n { | |||||
continue | |||||
} | |||||
if strings.EqualFold(mk, colName) { | |||||
return true | |||||
} | |||||
} | |||||
return false | |||||
} | |||||
func (m *columnMap) add(colName string) bool { | |||||
if m.contain(colName) { | |||||
return false | |||||
} | |||||
*m = append(*m, colName) | |||||
return true | |||||
} | |||||
func setColumnInt(bean interface{}, col *core.Column, t int64) { | |||||
v, err := col.ValueOf(bean) | |||||
if err != nil { | |||||
return | |||||
} | |||||
if v.CanSet() { | |||||
switch v.Type().Kind() { | |||||
case reflect.Int, reflect.Int64, reflect.Int32: | |||||
v.SetInt(t) | |||||
case reflect.Uint, reflect.Uint64, reflect.Uint32: | |||||
v.SetUint(uint64(t)) | |||||
} | |||||
} | |||||
} | |||||
func setColumnTime(bean interface{}, col *core.Column, t time.Time) { | |||||
v, err := col.ValueOf(bean) | |||||
if err != nil { | |||||
return | |||||
} | |||||
if v.CanSet() { | |||||
switch v.Type().Kind() { | |||||
case reflect.Struct: | |||||
v.Set(reflect.ValueOf(t).Convert(v.Type())) | |||||
case reflect.Int, reflect.Int64, reflect.Int32: | |||||
v.SetInt(t.Unix()) | |||||
case reflect.Uint, reflect.Uint64, reflect.Uint32: | |||||
v.SetUint(uint64(t.Unix())) | |||||
} | |||||
} | |||||
} | |||||
func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) { | |||||
if len(m) == 0 { | |||||
return false, false | |||||
} | |||||
n := len(col.Name) | |||||
for mk := range m { | |||||
if len(mk) != n { | |||||
continue | |||||
} | |||||
if strings.EqualFold(mk, col.Name) { | |||||
return m[mk], true | |||||
} | |||||
} | |||||
return false, false | |||||
} | |||||
func col2NewCols(columns ...string) []string { | |||||
newColumns := make([]string, 0, len(columns)) | |||||
for _, col := range columns { | |||||
col = strings.Replace(col, "`", "", -1) | |||||
col = strings.Replace(col, `"`, "", -1) | |||||
ccols := strings.Split(col, ",") | |||||
for _, c := range ccols { | |||||
newColumns = append(newColumns, strings.TrimSpace(c)) | |||||
} | |||||
} | |||||
return newColumns | |||||
} | |||||
// Incr provides a query string like "count = count + 1" | // Incr provides a query string like "count = count + 1" | ||||
func (session *Session) Incr(column string, arg ...interface{}) *Session { | func (session *Session) Incr(column string, arg ...interface{}) *Session { | ||||
session.statement.Incr(column, arg...) | session.statement.Incr(column, arg...) | ||||
@@ -27,7 +27,7 @@ func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, | |||||
return ErrCacheFailed | return ErrCacheFailed | ||||
} | } | ||||
cacher := session.engine.getCacher2(table) | |||||
cacher := session.engine.getCacher(tableName) | |||||
pkColumns := table.PKColumns() | pkColumns := table.PKColumns() | ||||
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) | ids, err := core.GetCacheSql(cacher, tableName, newsql, args) | ||||
if err != nil { | if err != nil { | ||||
@@ -79,7 +79,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { | |||||
defer session.Close() | defer session.Close() | ||||
} | } | ||||
if err := session.statement.setRefValue(rValue(bean)); err != nil { | |||||
if err := session.statement.setRefBean(bean); err != nil { | |||||
return 0, err | return 0, err | ||||
} | } | ||||
@@ -199,7 +199,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { | |||||
}) | }) | ||||
} | } | ||||
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { | |||||
if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { | |||||
session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) | session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) | ||||
} | } | ||||
@@ -57,7 +57,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { | |||||
} | } | ||||
if beanValue.Elem().Kind() == reflect.Struct { | if beanValue.Elem().Kind() == reflect.Struct { | ||||
if err := session.statement.setRefValue(beanValue.Elem()); err != nil { | |||||
if err := session.statement.setRefBean(bean[0]); err != nil { | |||||
return false, err | return false, err | ||||
} | } | ||||
} | } | ||||
@@ -29,6 +29,39 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) | |||||
return session.find(rowsSlicePtr, condiBean...) | return session.find(rowsSlicePtr, condiBean...) | ||||
} | } | ||||
// FindAndCount find the results and also return the counts | |||||
func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) { | |||||
if session.isAutoClose { | |||||
defer session.Close() | |||||
} | |||||
session.autoResetStatement = false | |||||
err := session.find(rowsSlicePtr, condiBean...) | |||||
if err != nil { | |||||
return 0, err | |||||
} | |||||
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) | |||||
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { | |||||
return 0, errors.New("needs a pointer to a slice or a map") | |||||
} | |||||
sliceElementType := sliceValue.Type().Elem() | |||||
if sliceElementType.Kind() == reflect.Ptr { | |||||
sliceElementType = sliceElementType.Elem() | |||||
} | |||||
session.autoResetStatement = true | |||||
if session.statement.selectStr != "" { | |||||
session.statement.selectStr = "" | |||||
} | |||||
if session.statement.OrderStr != "" { | |||||
session.statement.OrderStr = "" | |||||
} | |||||
return session.Count(reflect.New(sliceElementType).Interface()) | |||||
} | |||||
func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { | func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { | ||||
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) | sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) | ||||
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { | if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { | ||||
@@ -42,7 +75,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) | |||||
if sliceElementType.Kind() == reflect.Ptr { | if sliceElementType.Kind() == reflect.Ptr { | ||||
if sliceElementType.Elem().Kind() == reflect.Struct { | if sliceElementType.Elem().Kind() == reflect.Struct { | ||||
pv := reflect.New(sliceElementType.Elem()) | pv := reflect.New(sliceElementType.Elem()) | ||||
if err := session.statement.setRefValue(pv.Elem()); err != nil { | |||||
if err := session.statement.setRefValue(pv); err != nil { | |||||
return err | return err | ||||
} | } | ||||
} else { | } else { | ||||
@@ -50,7 +83,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) | |||||
} | } | ||||
} else if sliceElementType.Kind() == reflect.Struct { | } else if sliceElementType.Kind() == reflect.Struct { | ||||
pv := reflect.New(sliceElementType) | pv := reflect.New(sliceElementType) | ||||
if err := session.statement.setRefValue(pv.Elem()); err != nil { | |||||
if err := session.statement.setRefValue(pv); err != nil { | |||||
return err | return err | ||||
} | } | ||||
} else { | } else { | ||||
@@ -128,7 +161,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) | |||||
} | } | ||||
args = append(session.statement.joinArgs, condArgs...) | args = append(session.statement.joinArgs, condArgs...) | ||||
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL) | |||||
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true, true) | |||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
} | } | ||||
@@ -143,7 +176,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) | |||||
} | } | ||||
if session.canCache() { | if session.canCache() { | ||||
if cacher := session.engine.getCacher2(table); cacher != nil && | |||||
if cacher := session.engine.getCacher(table.Name); cacher != nil && | |||||
!session.statement.IsDistinct && | !session.statement.IsDistinct && | ||||
!session.statement.unscoped { | !session.statement.unscoped { | ||||
err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) | err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) | ||||
@@ -288,6 +321,12 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in | |||||
return ErrCacheFailed | return ErrCacheFailed | ||||
} | } | ||||
tableName := session.statement.TableName() | |||||
cacher := session.engine.getCacher(tableName) | |||||
if cacher == nil { | |||||
return nil | |||||
} | |||||
for _, filter := range session.engine.dialect.Filters() { | for _, filter := range session.engine.dialect.Filters() { | ||||
sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) | sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) | ||||
} | } | ||||
@@ -297,9 +336,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in | |||||
return ErrCacheFailed | return ErrCacheFailed | ||||
} | } | ||||
tableName := session.statement.TableName() | |||||
table := session.statement.RefTable | table := session.statement.RefTable | ||||
cacher := session.engine.getCacher2(table) | |||||
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) | ids, err := core.GetCacheSql(cacher, tableName, newsql, args) | ||||
if err != nil { | if err != nil { | ||||
rows, err := session.queryRows(newsql, args...) | rows, err := session.queryRows(newsql, args...) | ||||
@@ -31,7 +31,7 @@ func (session *Session) get(bean interface{}) (bool, error) { | |||||
} | } | ||||
if beanValue.Elem().Kind() == reflect.Struct { | if beanValue.Elem().Kind() == reflect.Struct { | ||||
if err := session.statement.setRefValue(beanValue.Elem()); err != nil { | |||||
if err := session.statement.setRefBean(bean); err != nil { | |||||
return false, err | return false, err | ||||
} | } | ||||
} | } | ||||
@@ -57,7 +57,7 @@ func (session *Session) get(bean interface{}) (bool, error) { | |||||
table := session.statement.RefTable | table := session.statement.RefTable | ||||
if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { | if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { | ||||
if cacher := session.engine.getCacher2(table); cacher != nil && | |||||
if cacher := session.engine.getCacher(table.Name); cacher != nil && | |||||
!session.statement.unscoped { | !session.statement.unscoped { | ||||
has, err := session.cacheGet(bean, sqlStr, args...) | has, err := session.cacheGet(bean, sqlStr, args...) | ||||
if err != ErrCacheFailed { | if err != ErrCacheFailed { | ||||
@@ -134,8 +134,9 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf | |||||
return false, ErrCacheFailed | return false, ErrCacheFailed | ||||
} | } | ||||
cacher := session.engine.getCacher2(session.statement.RefTable) | |||||
tableName := session.statement.TableName() | tableName := session.statement.TableName() | ||||
cacher := session.engine.getCacher(tableName) | |||||
session.engine.logger.Debug("[cacheGet] find sql:", newsql, args) | session.engine.logger.Debug("[cacheGet] find sql:", newsql, args) | ||||
table := session.statement.RefTable | table := session.statement.RefTable | ||||
ids, err := core.GetCacheSql(cacher, tableName, newsql, args) | ids, err := core.GetCacheSql(cacher, tableName, newsql, args) | ||||
@@ -66,11 +66,12 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error | |||||
return 0, errors.New("could not insert a empty slice") | return 0, errors.New("could not insert a empty slice") | ||||
} | } | ||||
if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil { | |||||
if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil { | |||||
return 0, err | return 0, err | ||||
} | } | ||||
if len(session.statement.TableName()) <= 0 { | |||||
tableName := session.statement.TableName() | |||||
if len(tableName) <= 0 { | |||||
return 0, ErrTableNotFound | return 0, ErrTableNotFound | ||||
} | } | ||||
@@ -115,15 +116,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error | |||||
if col.IsDeleted { | if col.IsDeleted { | ||||
continue | continue | ||||
} | } | ||||
if session.statement.ColumnStr != "" { | |||||
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { | |||||
continue | |||||
} | |||||
if session.statement.omitColumnMap.contain(col.Name) { | |||||
continue | |||||
} | } | ||||
if session.statement.OmitStr != "" { | |||||
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { | |||||
continue | |||||
} | |||||
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { | |||||
continue | |||||
} | } | ||||
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { | if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { | ||||
val, t := session.engine.nowTime(col) | val, t := session.engine.nowTime(col) | ||||
@@ -170,15 +167,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error | |||||
if col.IsDeleted { | if col.IsDeleted { | ||||
continue | continue | ||||
} | } | ||||
if session.statement.ColumnStr != "" { | |||||
if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { | |||||
continue | |||||
} | |||||
if session.statement.omitColumnMap.contain(col.Name) { | |||||
continue | |||||
} | } | ||||
if session.statement.OmitStr != "" { | |||||
if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { | |||||
continue | |||||
} | |||||
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { | |||||
continue | |||||
} | } | ||||
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { | if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { | ||||
val, t := session.engine.nowTime(col) | val, t := session.engine.nowTime(col) | ||||
@@ -211,38 +204,33 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error | |||||
} | } | ||||
cleanupProcessorsClosures(&session.beforeClosures) | cleanupProcessorsClosures(&session.beforeClosures) | ||||
var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)" | |||||
var statement string | |||||
var tableName = session.statement.TableName() | |||||
var sql string | |||||
if session.engine.dialect.DBType() == core.ORACLE { | if session.engine.dialect.DBType() == core.ORACLE { | ||||
sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL" | |||||
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", | temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", | ||||
session.engine.Quote(tableName), | session.engine.Quote(tableName), | ||||
session.engine.QuoteStr(), | session.engine.QuoteStr(), | ||||
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), | strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), | ||||
session.engine.QuoteStr()) | session.engine.QuoteStr()) | ||||
statement = fmt.Sprintf(sql, | |||||
sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL", | |||||
session.engine.Quote(tableName), | session.engine.Quote(tableName), | ||||
session.engine.QuoteStr(), | session.engine.QuoteStr(), | ||||
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), | strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), | ||||
session.engine.QuoteStr(), | session.engine.QuoteStr(), | ||||
strings.Join(colMultiPlaces, temp)) | strings.Join(colMultiPlaces, temp)) | ||||
} else { | } else { | ||||
statement = fmt.Sprintf(sql, | |||||
sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)", | |||||
session.engine.Quote(tableName), | session.engine.Quote(tableName), | ||||
session.engine.QuoteStr(), | session.engine.QuoteStr(), | ||||
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), | strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), | ||||
session.engine.QuoteStr(), | session.engine.QuoteStr(), | ||||
strings.Join(colMultiPlaces, "),(")) | strings.Join(colMultiPlaces, "),(")) | ||||
} | } | ||||
res, err := session.exec(statement, args...) | |||||
res, err := session.exec(sql, args...) | |||||
if err != nil { | if err != nil { | ||||
return 0, err | return 0, err | ||||
} | } | ||||
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { | |||||
session.cacheInsert(table, tableName) | |||||
} | |||||
session.cacheInsert(tableName) | |||||
lenAfterClosures := len(session.afterClosures) | lenAfterClosures := len(session.afterClosures) | ||||
for i := 0; i < size; i++ { | for i := 0; i < size; i++ { | ||||
@@ -298,7 +286,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { | |||||
} | } | ||||
func (session *Session) innerInsert(bean interface{}) (int64, error) { | func (session *Session) innerInsert(bean interface{}) (int64, error) { | ||||
if err := session.statement.setRefValue(rValue(bean)); err != nil { | |||||
if err := session.statement.setRefBean(bean); err != nil { | |||||
return 0, err | return 0, err | ||||
} | } | ||||
if len(session.statement.TableName()) <= 0 { | if len(session.statement.TableName()) <= 0 { | ||||
@@ -316,8 +304,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { | |||||
if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok { | if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok { | ||||
processor.BeforeInsert() | processor.BeforeInsert() | ||||
} | } | ||||
// -- | |||||
colNames, args, err := genCols(session.statement.RefTable, session, bean, false, false) | |||||
colNames, args, err := session.genInsertColumns(bean) | |||||
if err != nil { | if err != nil { | ||||
return 0, err | return 0, err | ||||
} | } | ||||
@@ -402,9 +390,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { | |||||
defer handleAfterInsertProcessorFunc(bean) | defer handleAfterInsertProcessorFunc(bean) | ||||
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { | |||||
session.cacheInsert(table, tableName) | |||||
} | |||||
session.cacheInsert(tableName) | |||||
if table.Version != "" && session.statement.checkVersion { | if table.Version != "" && session.statement.checkVersion { | ||||
verValue, err := table.VersionColumn().ValueOf(bean) | verValue, err := table.VersionColumn().ValueOf(bean) | ||||
@@ -447,9 +433,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { | |||||
} | } | ||||
defer handleAfterInsertProcessorFunc(bean) | defer handleAfterInsertProcessorFunc(bean) | ||||
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { | |||||
session.cacheInsert(table, tableName) | |||||
} | |||||
session.cacheInsert(tableName) | |||||
if table.Version != "" && session.statement.checkVersion { | if table.Version != "" && session.statement.checkVersion { | ||||
verValue, err := table.VersionColumn().ValueOf(bean) | verValue, err := table.VersionColumn().ValueOf(bean) | ||||
@@ -490,9 +474,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { | |||||
defer handleAfterInsertProcessorFunc(bean) | defer handleAfterInsertProcessorFunc(bean) | ||||
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { | |||||
session.cacheInsert(table, tableName) | |||||
} | |||||
session.cacheInsert(tableName) | |||||
if table.Version != "" && session.statement.checkVersion { | if table.Version != "" && session.statement.checkVersion { | ||||
verValue, err := table.VersionColumn().ValueOf(bean) | verValue, err := table.VersionColumn().ValueOf(bean) | ||||
@@ -539,16 +521,104 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { | |||||
return session.innerInsert(bean) | return session.innerInsert(bean) | ||||
} | } | ||||
func (session *Session) cacheInsert(table *core.Table, tables ...string) error { | |||||
if table == nil { | |||||
return ErrCacheFailed | |||||
func (session *Session) cacheInsert(table string) error { | |||||
if !session.statement.UseCache { | |||||
return nil | |||||
} | } | ||||
cacher := session.engine.getCacher2(table) | |||||
for _, t := range tables { | |||||
session.engine.logger.Debug("[cache] clear sql:", t) | |||||
cacher.ClearIds(t) | |||||
cacher := session.engine.getCacher(table) | |||||
if cacher == nil { | |||||
return nil | |||||
} | } | ||||
session.engine.logger.Debug("[cache] clear sql:", table) | |||||
cacher.ClearIds(table) | |||||
return nil | return nil | ||||
} | } | ||||
// genInsertColumns generates insert needed columns | |||||
func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) { | |||||
table := session.statement.RefTable | |||||
colNames := make([]string, 0, len(table.ColumnsSeq())) | |||||
args := make([]interface{}, 0, len(table.ColumnsSeq())) | |||||
for _, col := range table.Columns() { | |||||
if col.MapType == core.ONLYFROMDB { | |||||
continue | |||||
} | |||||
if col.IsDeleted { | |||||
continue | |||||
} | |||||
if session.statement.omitColumnMap.contain(col.Name) { | |||||
continue | |||||
} | |||||
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { | |||||
continue | |||||
} | |||||
if _, ok := session.statement.incrColumns[col.Name]; ok { | |||||
continue | |||||
} else if _, ok := session.statement.decrColumns[col.Name]; ok { | |||||
continue | |||||
} | |||||
fieldValuePtr, err := col.ValueOf(bean) | |||||
if err != nil { | |||||
return nil, nil, err | |||||
} | |||||
fieldValue := *fieldValuePtr | |||||
if col.IsAutoIncrement { | |||||
switch fieldValue.Type().Kind() { | |||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: | |||||
if fieldValue.Int() == 0 { | |||||
continue | |||||
} | |||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: | |||||
if fieldValue.Uint() == 0 { | |||||
continue | |||||
} | |||||
case reflect.String: | |||||
if len(fieldValue.String()) == 0 { | |||||
continue | |||||
} | |||||
case reflect.Ptr: | |||||
if fieldValue.Pointer() == 0 { | |||||
continue | |||||
} | |||||
} | |||||
} | |||||
// !evalphobia! set fieldValue as nil when column is nullable and zero-value | |||||
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { | |||||
if col.Nullable && isZero(fieldValue.Interface()) { | |||||
var nilValue *int | |||||
fieldValue = reflect.ValueOf(nilValue) | |||||
} | |||||
} | |||||
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { | |||||
// if time is non-empty, then set to auto time | |||||
val, t := session.engine.nowTime(col) | |||||
args = append(args, val) | |||||
var colName = col.Name | |||||
session.afterClosures = append(session.afterClosures, func(bean interface{}) { | |||||
col := table.GetColumn(colName) | |||||
setColumnTime(bean, col, t) | |||||
}) | |||||
} else if col.IsVersion && session.statement.checkVersion { | |||||
args = append(args, 1) | |||||
} else { | |||||
arg, err := session.value2Interface(col, fieldValue) | |||||
if err != nil { | |||||
return colNames, args, err | |||||
} | |||||
args = append(args, arg) | |||||
} | |||||
colNames = append(colNames, col.Name) | |||||
} | |||||
return colNames, args, nil | |||||
} |
@@ -64,13 +64,17 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa | |||||
} | } | ||||
} | } | ||||
if err := session.statement.processIDParam(); err != nil { | |||||
return "", nil, err | |||||
} | |||||
condSQL, condArgs, err := builder.ToSQL(session.statement.cond) | condSQL, condArgs, err := builder.ToSQL(session.statement.cond) | ||||
if err != nil { | if err != nil { | ||||
return "", nil, err | return "", nil, err | ||||
} | } | ||||
args := append(session.statement.joinArgs, condArgs...) | args := append(session.statement.joinArgs, condArgs...) | ||||
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL) | |||||
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true, true) | |||||
if err != nil { | if err != nil { | ||||
return "", nil, err | return "", nil, err | ||||
} | } | ||||
@@ -6,9 +6,7 @@ package xorm | |||||
import ( | import ( | ||||
"database/sql" | "database/sql" | ||||
"errors" | |||||
"fmt" | "fmt" | ||||
"reflect" | |||||
"strings" | "strings" | ||||
"github.com/go-xorm/core" | "github.com/go-xorm/core" | ||||
@@ -34,8 +32,7 @@ func (session *Session) CreateTable(bean interface{}) error { | |||||
} | } | ||||
func (session *Session) createTable(bean interface{}) error { | func (session *Session) createTable(bean interface{}) error { | ||||
v := rValue(bean) | |||||
if err := session.statement.setRefValue(v); err != nil { | |||||
if err := session.statement.setRefBean(bean); err != nil { | |||||
return err | return err | ||||
} | } | ||||
@@ -54,8 +51,7 @@ func (session *Session) CreateIndexes(bean interface{}) error { | |||||
} | } | ||||
func (session *Session) createIndexes(bean interface{}) error { | func (session *Session) createIndexes(bean interface{}) error { | ||||
v := rValue(bean) | |||||
if err := session.statement.setRefValue(v); err != nil { | |||||
if err := session.statement.setRefBean(bean); err != nil { | |||||
return err | return err | ||||
} | } | ||||
@@ -78,8 +74,7 @@ func (session *Session) CreateUniques(bean interface{}) error { | |||||
} | } | ||||
func (session *Session) createUniques(bean interface{}) error { | func (session *Session) createUniques(bean interface{}) error { | ||||
v := rValue(bean) | |||||
if err := session.statement.setRefValue(v); err != nil { | |||||
if err := session.statement.setRefBean(bean); err != nil { | |||||
return err | return err | ||||
} | } | ||||
@@ -103,8 +98,7 @@ func (session *Session) DropIndexes(bean interface{}) error { | |||||
} | } | ||||
func (session *Session) dropIndexes(bean interface{}) error { | func (session *Session) dropIndexes(bean interface{}) error { | ||||
v := rValue(bean) | |||||
if err := session.statement.setRefValue(v); err != nil { | |||||
if err := session.statement.setRefBean(bean); err != nil { | |||||
return err | return err | ||||
} | } | ||||
@@ -128,11 +122,7 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { | |||||
} | } | ||||
func (session *Session) dropTable(beanOrTableName interface{}) error { | func (session *Session) dropTable(beanOrTableName interface{}) error { | ||||
tableName, err := session.engine.tableName(beanOrTableName) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
tableName := session.engine.TableName(beanOrTableName) | |||||
var needDrop = true | var needDrop = true | ||||
if !session.engine.dialect.SupportDropIfExists() { | if !session.engine.dialect.SupportDropIfExists() { | ||||
sqlStr, args := session.engine.dialect.TableCheckSql(tableName) | sqlStr, args := session.engine.dialect.TableCheckSql(tableName) | ||||
@@ -144,8 +134,8 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { | |||||
} | } | ||||
if needDrop { | if needDrop { | ||||
sqlStr := session.engine.Dialect().DropTableSql(tableName) | |||||
_, err = session.exec(sqlStr) | |||||
sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true)) | |||||
_, err := session.exec(sqlStr) | |||||
return err | return err | ||||
} | } | ||||
return nil | return nil | ||||
@@ -157,10 +147,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) | |||||
defer session.Close() | defer session.Close() | ||||
} | } | ||||
tableName, err := session.engine.tableName(beanOrTableName) | |||||
if err != nil { | |||||
return false, err | |||||
} | |||||
tableName := session.engine.TableName(beanOrTableName) | |||||
return session.isTableExist(tableName) | return session.isTableExist(tableName) | ||||
} | } | ||||
@@ -173,24 +160,15 @@ func (session *Session) isTableExist(tableName string) (bool, error) { | |||||
// IsTableEmpty if table have any records | // IsTableEmpty if table have any records | ||||
func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { | func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { | ||||
v := rValue(bean) | |||||
t := v.Type() | |||||
if t.Kind() == reflect.String { | |||||
if session.isAutoClose { | |||||
defer session.Close() | |||||
} | |||||
return session.isTableEmpty(bean.(string)) | |||||
} else if t.Kind() == reflect.Struct { | |||||
rows, err := session.Count(bean) | |||||
return rows == 0, err | |||||
if session.isAutoClose { | |||||
defer session.Close() | |||||
} | } | ||||
return false, errors.New("bean should be a struct or struct's point") | |||||
return session.isTableEmpty(session.engine.TableName(bean)) | |||||
} | } | ||||
func (session *Session) isTableEmpty(tableName string) (bool, error) { | func (session *Session) isTableEmpty(tableName string) (bool, error) { | ||||
var total int64 | var total int64 | ||||
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName)) | |||||
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true))) | |||||
err := session.queryRow(sqlStr).Scan(&total) | err := session.queryRow(sqlStr).Scan(&total) | ||||
if err != nil { | if err != nil { | ||||
if err == sql.ErrNoRows { | if err == sql.ErrNoRows { | ||||
@@ -255,6 +233,12 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||||
return err | return err | ||||
} | } | ||||
session.autoResetStatement = false | |||||
defer func() { | |||||
session.autoResetStatement = true | |||||
session.resetStatement() | |||||
}() | |||||
var structTables []*core.Table | var structTables []*core.Table | ||||
for _, bean := range beans { | for _, bean := range beans { | ||||
@@ -264,7 +248,8 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||||
return err | return err | ||||
} | } | ||||
structTables = append(structTables, table) | structTables = append(structTables, table) | ||||
var tbName = session.tbNameNoSchema(table) | |||||
tbName := engine.TableName(bean) | |||||
tbNameWithSchema := engine.TableName(tbName, true) | |||||
var oriTable *core.Table | var oriTable *core.Table | ||||
for _, tb := range tables { | for _, tb := range tables { | ||||
@@ -309,32 +294,32 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||||
if engine.dialect.DBType() == core.MYSQL || | if engine.dialect.DBType() == core.MYSQL || | ||||
engine.dialect.DBType() == core.POSTGRES { | engine.dialect.DBType() == core.POSTGRES { | ||||
engine.logger.Infof("Table %s column %s change type from %s to %s\n", | engine.logger.Infof("Table %s column %s change type from %s to %s\n", | ||||
tbName, col.Name, curType, expectedType) | |||||
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) | |||||
tbNameWithSchema, col.Name, curType, expectedType) | |||||
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) | |||||
} else { | } else { | ||||
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", | engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", | ||||
tbName, col.Name, curType, expectedType) | |||||
tbNameWithSchema, col.Name, curType, expectedType) | |||||
} | } | ||||
} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) { | } else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) { | ||||
if engine.dialect.DBType() == core.MYSQL { | if engine.dialect.DBType() == core.MYSQL { | ||||
if oriCol.Length < col.Length { | if oriCol.Length < col.Length { | ||||
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", | engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", | ||||
tbName, col.Name, oriCol.Length, col.Length) | |||||
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) | |||||
tbNameWithSchema, col.Name, oriCol.Length, col.Length) | |||||
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) | |||||
} | } | ||||
} | } | ||||
} else { | } else { | ||||
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { | if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { | ||||
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", | engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", | ||||
tbName, col.Name, curType, expectedType) | |||||
tbNameWithSchema, col.Name, curType, expectedType) | |||||
} | } | ||||
} | } | ||||
} else if expectedType == core.Varchar { | } else if expectedType == core.Varchar { | ||||
if engine.dialect.DBType() == core.MYSQL { | if engine.dialect.DBType() == core.MYSQL { | ||||
if oriCol.Length < col.Length { | if oriCol.Length < col.Length { | ||||
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", | engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", | ||||
tbName, col.Name, oriCol.Length, col.Length) | |||||
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) | |||||
tbNameWithSchema, col.Name, oriCol.Length, col.Length) | |||||
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -348,7 +333,7 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||||
} | } | ||||
} else { | } else { | ||||
session.statement.RefTable = table | session.statement.RefTable = table | ||||
session.statement.tableName = tbName | |||||
session.statement.tableName = tbNameWithSchema | |||||
err = session.addColumn(col.Name) | err = session.addColumn(col.Name) | ||||
} | } | ||||
if err != nil { | if err != nil { | ||||
@@ -371,7 +356,7 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||||
if oriIndex != nil { | if oriIndex != nil { | ||||
if oriIndex.Type != index.Type { | if oriIndex.Type != index.Type { | ||||
sql := engine.dialect.DropIndexSql(tbName, oriIndex) | |||||
sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex) | |||||
_, err = session.exec(sql) | _, err = session.exec(sql) | ||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
@@ -387,7 +372,7 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||||
for name2, index2 := range oriTable.Indexes { | for name2, index2 := range oriTable.Indexes { | ||||
if _, ok := foundIndexNames[name2]; !ok { | if _, ok := foundIndexNames[name2]; !ok { | ||||
sql := engine.dialect.DropIndexSql(tbName, index2) | |||||
sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2) | |||||
_, err = session.exec(sql) | _, err = session.exec(sql) | ||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
@@ -398,12 +383,12 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||||
for name, index := range addedNames { | for name, index := range addedNames { | ||||
if index.Type == core.UniqueType { | if index.Type == core.UniqueType { | ||||
session.statement.RefTable = table | session.statement.RefTable = table | ||||
session.statement.tableName = tbName | |||||
err = session.addUnique(tbName, name) | |||||
session.statement.tableName = tbNameWithSchema | |||||
err = session.addUnique(tbNameWithSchema, name) | |||||
} else if index.Type == core.IndexType { | } else if index.Type == core.IndexType { | ||||
session.statement.RefTable = table | session.statement.RefTable = table | ||||
session.statement.tableName = tbName | |||||
err = session.addIndex(tbName, name) | |||||
session.statement.tableName = tbNameWithSchema | |||||
err = session.addIndex(tbNameWithSchema, name) | |||||
} | } | ||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
@@ -428,7 +413,7 @@ func (session *Session) Sync2(beans ...interface{}) error { | |||||
for _, colName := range table.ColumnsSeq() { | for _, colName := range table.ColumnsSeq() { | ||||
if oriTable.GetColumn(colName) == nil { | if oriTable.GetColumn(colName) == nil { | ||||
engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName) | |||||
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(table.Name, true), colName) | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -24,6 +24,7 @@ func (session *Session) Rollback() error { | |||||
if !session.isAutoCommit && !session.isCommitedOrRollbacked { | if !session.isAutoCommit && !session.isCommitedOrRollbacked { | ||||
session.saveLastSQL(session.engine.dialect.RollBackStr()) | session.saveLastSQL(session.engine.dialect.RollBackStr()) | ||||
session.isCommitedOrRollbacked = true | session.isCommitedOrRollbacked = true | ||||
session.isAutoCommit = true | |||||
return session.tx.Rollback() | return session.tx.Rollback() | ||||
} | } | ||||
return nil | return nil | ||||
@@ -34,6 +35,7 @@ func (session *Session) Commit() error { | |||||
if !session.isAutoCommit && !session.isCommitedOrRollbacked { | if !session.isAutoCommit && !session.isCommitedOrRollbacked { | ||||
session.saveLastSQL("COMMIT") | session.saveLastSQL("COMMIT") | ||||
session.isCommitedOrRollbacked = true | session.isCommitedOrRollbacked = true | ||||
session.isAutoCommit = true | |||||
var err error | var err error | ||||
if err = session.tx.Commit(); err == nil { | if err = session.tx.Commit(); err == nil { | ||||
// handle processors after tx committed | // handle processors after tx committed | ||||
@@ -40,7 +40,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, | |||||
} | } | ||||
} | } | ||||
cacher := session.engine.getCacher2(table) | |||||
cacher := session.engine.getCacher(tableName) | |||||
session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) | session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) | ||||
ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) | ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) | ||||
if err != nil { | if err != nil { | ||||
@@ -167,7 +167,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 | |||||
var isMap = t.Kind() == reflect.Map | var isMap = t.Kind() == reflect.Map | ||||
var isStruct = t.Kind() == reflect.Struct | var isStruct = t.Kind() == reflect.Struct | ||||
if isStruct { | if isStruct { | ||||
if err := session.statement.setRefValue(v); err != nil { | |||||
if err := session.statement.setRefBean(bean); err != nil { | |||||
return 0, err | return 0, err | ||||
} | } | ||||
@@ -176,12 +176,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 | |||||
} | } | ||||
if session.statement.ColumnStr == "" { | if session.statement.ColumnStr == "" { | ||||
colNames, args = buildUpdates(session.engine, session.statement.RefTable, bean, false, false, | |||||
false, false, session.statement.allUseBool, session.statement.useAllCols, | |||||
session.statement.mustColumnMap, session.statement.nullableMap, | |||||
session.statement.columnMap, true, session.statement.unscoped) | |||||
colNames, args = session.statement.buildUpdates(bean, false, false, | |||||
false, false, true) | |||||
} else { | } else { | ||||
colNames, args, err = genCols(session.statement.RefTable, session, bean, true, true) | |||||
colNames, args, err = session.genUpdateColumns(bean) | |||||
if err != nil { | if err != nil { | ||||
return 0, err | return 0, err | ||||
} | } | ||||
@@ -202,7 +200,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 | |||||
table := session.statement.RefTable | table := session.statement.RefTable | ||||
if session.statement.UseAutoTime && table != nil && table.Updated != "" { | if session.statement.UseAutoTime && table != nil && table.Updated != "" { | ||||
if _, ok := session.statement.columnMap[strings.ToLower(table.Updated)]; !ok { | |||||
if !session.statement.columnMap.contain(table.Updated) && | |||||
!session.statement.omitColumnMap.contain(table.Updated) { | |||||
colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") | colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") | ||||
col := table.UpdatedColumn() | col := table.UpdatedColumn() | ||||
val, t := session.engine.nowTime(col) | val, t := session.engine.nowTime(col) | ||||
@@ -362,12 +361,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 | |||||
} | } | ||||
} | } | ||||
if table != nil { | |||||
if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { | |||||
//session.cacheUpdate(table, tableName, sqlStr, args...) | |||||
cacher.ClearIds(tableName) | |||||
cacher.ClearBeans(tableName) | |||||
} | |||||
if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { | |||||
//session.cacheUpdate(table, tableName, sqlStr, args...) | |||||
session.engine.logger.Debug("[cacheUpdate] clear table ", tableName) | |||||
cacher.ClearIds(tableName) | |||||
cacher.ClearBeans(tableName) | |||||
} | } | ||||
// handle after update processors | // handle after update processors | ||||
@@ -402,3 +400,92 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 | |||||
return res.RowsAffected() | return res.RowsAffected() | ||||
} | } | ||||
func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interface{}, error) { | |||||
table := session.statement.RefTable | |||||
colNames := make([]string, 0, len(table.ColumnsSeq())) | |||||
args := make([]interface{}, 0, len(table.ColumnsSeq())) | |||||
for _, col := range table.Columns() { | |||||
if !col.IsVersion && !col.IsCreated && !col.IsUpdated { | |||||
if session.statement.omitColumnMap.contain(col.Name) { | |||||
continue | |||||
} | |||||
} | |||||
if col.MapType == core.ONLYFROMDB { | |||||
continue | |||||
} | |||||
fieldValuePtr, err := col.ValueOf(bean) | |||||
if err != nil { | |||||
return nil, nil, err | |||||
} | |||||
fieldValue := *fieldValuePtr | |||||
if col.IsAutoIncrement { | |||||
switch fieldValue.Type().Kind() { | |||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: | |||||
if fieldValue.Int() == 0 { | |||||
continue | |||||
} | |||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: | |||||
if fieldValue.Uint() == 0 { | |||||
continue | |||||
} | |||||
case reflect.String: | |||||
if len(fieldValue.String()) == 0 { | |||||
continue | |||||
} | |||||
case reflect.Ptr: | |||||
if fieldValue.Pointer() == 0 { | |||||
continue | |||||
} | |||||
} | |||||
} | |||||
if col.IsDeleted || col.IsCreated { | |||||
continue | |||||
} | |||||
if len(session.statement.columnMap) > 0 { | |||||
if !session.statement.columnMap.contain(col.Name) { | |||||
continue | |||||
} else if _, ok := session.statement.incrColumns[col.Name]; ok { | |||||
continue | |||||
} else if _, ok := session.statement.decrColumns[col.Name]; ok { | |||||
continue | |||||
} | |||||
} | |||||
// !evalphobia! set fieldValue as nil when column is nullable and zero-value | |||||
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { | |||||
if col.Nullable && isZero(fieldValue.Interface()) { | |||||
var nilValue *int | |||||
fieldValue = reflect.ValueOf(nilValue) | |||||
} | |||||
} | |||||
if col.IsUpdated && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { | |||||
// if time is non-empty, then set to auto time | |||||
val, t := session.engine.nowTime(col) | |||||
args = append(args, val) | |||||
var colName = col.Name | |||||
session.afterClosures = append(session.afterClosures, func(bean interface{}) { | |||||
col := table.GetColumn(colName) | |||||
setColumnTime(bean, col, t) | |||||
}) | |||||
} else if col.IsVersion && session.statement.checkVersion { | |||||
args = append(args, 1) | |||||
} else { | |||||
arg, err := session.value2Interface(col, fieldValue) | |||||
if err != nil { | |||||
return colNames, args, err | |||||
} | |||||
args = append(args, arg) | |||||
} | |||||
colNames = append(colNames, session.engine.Quote(col.Name)+" = ?") | |||||
} | |||||
return colNames, args, nil | |||||
} |
@@ -5,7 +5,6 @@ | |||||
package xorm | package xorm | ||||
import ( | import ( | ||||
"bytes" | |||||
"database/sql/driver" | "database/sql/driver" | ||||
"encoding/json" | "encoding/json" | ||||
"errors" | "errors" | ||||
@@ -18,21 +17,6 @@ import ( | |||||
"github.com/go-xorm/core" | "github.com/go-xorm/core" | ||||
) | ) | ||||
type incrParam struct { | |||||
colName string | |||||
arg interface{} | |||||
} | |||||
type decrParam struct { | |||||
colName string | |||||
arg interface{} | |||||
} | |||||
type exprParam struct { | |||||
colName string | |||||
expr string | |||||
} | |||||
// Statement save all the sql info for executing SQL | // Statement save all the sql info for executing SQL | ||||
type Statement struct { | type Statement struct { | ||||
RefTable *core.Table | RefTable *core.Table | ||||
@@ -47,7 +31,6 @@ type Statement struct { | |||||
HavingStr string | HavingStr string | ||||
ColumnStr string | ColumnStr string | ||||
selectStr string | selectStr string | ||||
columnMap map[string]bool | |||||
useAllCols bool | useAllCols bool | ||||
OmitStr string | OmitStr string | ||||
AltTableName string | AltTableName string | ||||
@@ -67,6 +50,8 @@ type Statement struct { | |||||
allUseBool bool | allUseBool bool | ||||
checkVersion bool | checkVersion bool | ||||
unscoped bool | unscoped bool | ||||
columnMap columnMap | |||||
omitColumnMap columnMap | |||||
mustColumnMap map[string]bool | mustColumnMap map[string]bool | ||||
nullableMap map[string]bool | nullableMap map[string]bool | ||||
incrColumns map[string]incrParam | incrColumns map[string]incrParam | ||||
@@ -89,7 +74,8 @@ func (statement *Statement) Init() { | |||||
statement.HavingStr = "" | statement.HavingStr = "" | ||||
statement.ColumnStr = "" | statement.ColumnStr = "" | ||||
statement.OmitStr = "" | statement.OmitStr = "" | ||||
statement.columnMap = make(map[string]bool) | |||||
statement.columnMap = columnMap{} | |||||
statement.omitColumnMap = columnMap{} | |||||
statement.AltTableName = "" | statement.AltTableName = "" | ||||
statement.tableName = "" | statement.tableName = "" | ||||
statement.idParam = nil | statement.idParam = nil | ||||
@@ -221,34 +207,33 @@ func (statement *Statement) setRefValue(v reflect.Value) error { | |||||
if err != nil { | if err != nil { | ||||
return err | return err | ||||
} | } | ||||
statement.tableName = statement.Engine.tbName(v) | |||||
statement.tableName = statement.Engine.TableName(v, true) | |||||
return nil | return nil | ||||
} | } | ||||
// Table tempororily set table name, the parameter could be a string or a pointer of struct | |||||
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { | |||||
v := rValue(tableNameOrBean) | |||||
t := v.Type() | |||||
if t.Kind() == reflect.String { | |||||
statement.AltTableName = tableNameOrBean.(string) | |||||
} else if t.Kind() == reflect.Struct { | |||||
var err error | |||||
statement.RefTable, err = statement.Engine.autoMapType(v) | |||||
if err != nil { | |||||
statement.Engine.logger.Error(err) | |||||
return statement | |||||
} | |||||
statement.AltTableName = statement.Engine.tbName(v) | |||||
func (statement *Statement) setRefBean(bean interface{}) error { | |||||
var err error | |||||
statement.RefTable, err = statement.Engine.autoMapType(rValue(bean)) | |||||
if err != nil { | |||||
return err | |||||
} | } | ||||
return statement | |||||
statement.tableName = statement.Engine.TableName(bean, true) | |||||
return nil | |||||
} | } | ||||
// Auto generating update columnes and values according a struct | // Auto generating update columnes and values according a struct | ||||
func buildUpdates(engine *Engine, table *core.Table, bean interface{}, | |||||
includeVersion bool, includeUpdated bool, includeNil bool, | |||||
includeAutoIncr bool, allUseBool bool, useAllCols bool, | |||||
mustColumnMap map[string]bool, nullableMap map[string]bool, | |||||
columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) { | |||||
func (statement *Statement) buildUpdates(bean interface{}, | |||||
includeVersion, includeUpdated, includeNil, | |||||
includeAutoIncr, update bool) ([]string, []interface{}) { | |||||
engine := statement.Engine | |||||
table := statement.RefTable | |||||
allUseBool := statement.allUseBool | |||||
useAllCols := statement.useAllCols | |||||
mustColumnMap := statement.mustColumnMap | |||||
nullableMap := statement.nullableMap | |||||
columnMap := statement.columnMap | |||||
omitColumnMap := statement.omitColumnMap | |||||
unscoped := statement.unscoped | |||||
var colNames = make([]string, 0) | var colNames = make([]string, 0) | ||||
var args = make([]interface{}, 0) | var args = make([]interface{}, 0) | ||||
@@ -268,7 +253,14 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, | |||||
if col.IsDeleted && !unscoped { | if col.IsDeleted && !unscoped { | ||||
continue | continue | ||||
} | } | ||||
if use, ok := columnMap[strings.ToLower(col.Name)]; ok && !use { | |||||
if omitColumnMap.contain(col.Name) { | |||||
continue | |||||
} | |||||
if len(columnMap) > 0 && !columnMap.contain(col.Name) { | |||||
continue | |||||
} | |||||
if col.MapType == core.ONLYFROMDB { | |||||
continue | continue | ||||
} | } | ||||
@@ -604,17 +596,10 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { | |||||
} | } | ||||
func (statement *Statement) colmap2NewColsWithQuote() []string { | func (statement *Statement) colmap2NewColsWithQuote() []string { | ||||
newColumns := make([]string, 0, len(statement.columnMap)) | |||||
for col := range statement.columnMap { | |||||
fields := strings.Split(strings.TrimSpace(col), ".") | |||||
if len(fields) == 1 { | |||||
newColumns = append(newColumns, statement.Engine.quote(fields[0])) | |||||
} else if len(fields) == 2 { | |||||
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+ | |||||
statement.Engine.quote(fields[1])) | |||||
} else { | |||||
panic(errors.New("unwanted colnames")) | |||||
} | |||||
newColumns := make([]string, len(statement.columnMap), len(statement.columnMap)) | |||||
copy(newColumns, statement.columnMap) | |||||
for i := 0; i < len(statement.columnMap); i++ { | |||||
newColumns[i] = statement.Engine.Quote(newColumns[i]) | |||||
} | } | ||||
return newColumns | return newColumns | ||||
} | } | ||||
@@ -642,10 +627,11 @@ func (statement *Statement) Select(str string) *Statement { | |||||
func (statement *Statement) Cols(columns ...string) *Statement { | func (statement *Statement) Cols(columns ...string) *Statement { | ||||
cols := col2NewCols(columns...) | cols := col2NewCols(columns...) | ||||
for _, nc := range cols { | for _, nc := range cols { | ||||
statement.columnMap[strings.ToLower(nc)] = true | |||||
statement.columnMap.add(nc) | |||||
} | } | ||||
newColumns := statement.colmap2NewColsWithQuote() | newColumns := statement.colmap2NewColsWithQuote() | ||||
statement.ColumnStr = strings.Join(newColumns, ", ") | statement.ColumnStr = strings.Join(newColumns, ", ") | ||||
statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) | statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) | ||||
return statement | return statement | ||||
@@ -680,7 +666,7 @@ func (statement *Statement) UseBool(columns ...string) *Statement { | |||||
func (statement *Statement) Omit(columns ...string) { | func (statement *Statement) Omit(columns ...string) { | ||||
newColumns := col2NewCols(columns...) | newColumns := col2NewCols(columns...) | ||||
for _, nc := range newColumns { | for _, nc := range newColumns { | ||||
statement.columnMap[strings.ToLower(nc)] = false | |||||
statement.omitColumnMap = append(statement.omitColumnMap, nc) | |||||
} | } | ||||
statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) | statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) | ||||
} | } | ||||
@@ -719,10 +705,9 @@ func (statement *Statement) OrderBy(order string) *Statement { | |||||
// Desc generate `ORDER BY xx DESC` | // Desc generate `ORDER BY xx DESC` | ||||
func (statement *Statement) Desc(colNames ...string) *Statement { | func (statement *Statement) Desc(colNames ...string) *Statement { | ||||
var buf bytes.Buffer | |||||
fmt.Fprintf(&buf, statement.OrderStr) | |||||
var buf builder.StringBuilder | |||||
if len(statement.OrderStr) > 0 { | if len(statement.OrderStr) > 0 { | ||||
fmt.Fprint(&buf, ", ") | |||||
fmt.Fprint(&buf, statement.OrderStr, ", ") | |||||
} | } | ||||
newColNames := statement.col2NewColsWithQuote(colNames...) | newColNames := statement.col2NewColsWithQuote(colNames...) | ||||
fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, ")) | fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, ")) | ||||
@@ -732,10 +717,9 @@ func (statement *Statement) Desc(colNames ...string) *Statement { | |||||
// Asc provide asc order by query condition, the input parameters are columns. | // Asc provide asc order by query condition, the input parameters are columns. | ||||
func (statement *Statement) Asc(colNames ...string) *Statement { | func (statement *Statement) Asc(colNames ...string) *Statement { | ||||
var buf bytes.Buffer | |||||
fmt.Fprintf(&buf, statement.OrderStr) | |||||
var buf builder.StringBuilder | |||||
if len(statement.OrderStr) > 0 { | if len(statement.OrderStr) > 0 { | ||||
fmt.Fprint(&buf, ", ") | |||||
fmt.Fprint(&buf, statement.OrderStr, ", ") | |||||
} | } | ||||
newColNames := statement.col2NewColsWithQuote(colNames...) | newColNames := statement.col2NewColsWithQuote(colNames...) | ||||
fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, ")) | fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, ")) | ||||
@@ -743,48 +727,35 @@ func (statement *Statement) Asc(colNames ...string) *Statement { | |||||
return statement | return statement | ||||
} | } | ||||
// Table tempororily set table name, the parameter could be a string or a pointer of struct | |||||
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { | |||||
v := rValue(tableNameOrBean) | |||||
t := v.Type() | |||||
if t.Kind() == reflect.Struct { | |||||
var err error | |||||
statement.RefTable, err = statement.Engine.autoMapType(v) | |||||
if err != nil { | |||||
statement.Engine.logger.Error(err) | |||||
return statement | |||||
} | |||||
} | |||||
statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true) | |||||
return statement | |||||
} | |||||
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN | // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN | ||||
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { | func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { | ||||
var buf bytes.Buffer | |||||
var buf builder.StringBuilder | |||||
if len(statement.JoinStr) > 0 { | if len(statement.JoinStr) > 0 { | ||||
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) | fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) | ||||
} else { | } else { | ||||
fmt.Fprintf(&buf, "%v JOIN ", joinOP) | fmt.Fprintf(&buf, "%v JOIN ", joinOP) | ||||
} | } | ||||
switch tablename.(type) { | |||||
case []string: | |||||
t := tablename.([]string) | |||||
if len(t) > 1 { | |||||
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1])) | |||||
} else if len(t) == 1 { | |||||
fmt.Fprintf(&buf, statement.Engine.Quote(t[0])) | |||||
} | |||||
case []interface{}: | |||||
t := tablename.([]interface{}) | |||||
l := len(t) | |||||
var table string | |||||
if l > 0 { | |||||
f := t[0] | |||||
v := rValue(f) | |||||
t := v.Type() | |||||
if t.Kind() == reflect.String { | |||||
table = f.(string) | |||||
} else if t.Kind() == reflect.Struct { | |||||
table = statement.Engine.tbName(v) | |||||
} | |||||
} | |||||
if l > 1 { | |||||
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table), | |||||
statement.Engine.Quote(fmt.Sprintf("%v", t[1]))) | |||||
} else if l == 1 { | |||||
fmt.Fprintf(&buf, statement.Engine.Quote(table)) | |||||
} | |||||
default: | |||||
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename))) | |||||
} | |||||
tbName := statement.Engine.TableName(tablename, true) | |||||
fmt.Fprintf(&buf, " ON %v", condition) | |||||
fmt.Fprintf(&buf, "%s ON %v", tbName, condition) | |||||
statement.JoinStr = buf.String() | statement.JoinStr = buf.String() | ||||
statement.joinArgs = append(statement.joinArgs, args...) | statement.joinArgs = append(statement.joinArgs, args...) | ||||
return statement | return statement | ||||
@@ -809,18 +780,20 @@ func (statement *Statement) Unscoped() *Statement { | |||||
} | } | ||||
func (statement *Statement) genColumnStr() string { | func (statement *Statement) genColumnStr() string { | ||||
var buf bytes.Buffer | |||||
if statement.RefTable == nil { | if statement.RefTable == nil { | ||||
return "" | return "" | ||||
} | } | ||||
var buf builder.StringBuilder | |||||
columns := statement.RefTable.Columns() | columns := statement.RefTable.Columns() | ||||
for _, col := range columns { | for _, col := range columns { | ||||
if statement.OmitStr != "" { | |||||
if _, ok := getFlagForColumn(statement.columnMap, col); ok { | |||||
continue | |||||
} | |||||
if statement.omitColumnMap.contain(col.Name) { | |||||
continue | |||||
} | |||||
if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) { | |||||
continue | |||||
} | } | ||||
if col.MapType == core.ONLYTODB { | if col.MapType == core.ONLYTODB { | ||||
@@ -831,10 +804,6 @@ func (statement *Statement) genColumnStr() string { | |||||
buf.WriteString(", ") | buf.WriteString(", ") | ||||
} | } | ||||
if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { | |||||
buf.WriteString("id() AS ") | |||||
} | |||||
if statement.JoinStr != "" { | if statement.JoinStr != "" { | ||||
if statement.TableAlias != "" { | if statement.TableAlias != "" { | ||||
buf.WriteString(statement.TableAlias) | buf.WriteString(statement.TableAlias) | ||||
@@ -859,11 +828,13 @@ func (statement *Statement) genCreateTableSQL() string { | |||||
func (statement *Statement) genIndexSQL() []string { | func (statement *Statement) genIndexSQL() []string { | ||||
var sqls []string | var sqls []string | ||||
tbName := statement.TableName() | tbName := statement.TableName() | ||||
quote := statement.Engine.Quote | |||||
for idxName, index := range statement.RefTable.Indexes { | |||||
for _, index := range statement.RefTable.Indexes { | |||||
if index.Type == core.IndexType { | if index.Type == core.IndexType { | ||||
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), | |||||
quote(tbName), quote(strings.Join(index.Cols, quote(",")))) | |||||
sql := statement.Engine.dialect.CreateIndexSql(tbName, index) | |||||
/*idxTBName := strings.Replace(tbName, ".", "_", -1) | |||||
idxTBName = strings.Replace(idxTBName, `"`, "", -1) | |||||
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)), | |||||
quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/ | |||||
sqls = append(sqls, sql) | sqls = append(sqls, sql) | ||||
} | } | ||||
} | } | ||||
@@ -889,16 +860,18 @@ func (statement *Statement) genUniqueSQL() []string { | |||||
func (statement *Statement) genDelIndexSQL() []string { | func (statement *Statement) genDelIndexSQL() []string { | ||||
var sqls []string | var sqls []string | ||||
tbName := statement.TableName() | tbName := statement.TableName() | ||||
idxPrefixName := strings.Replace(tbName, `"`, "", -1) | |||||
idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1) | |||||
for idxName, index := range statement.RefTable.Indexes { | for idxName, index := range statement.RefTable.Indexes { | ||||
var rIdxName string | var rIdxName string | ||||
if index.Type == core.UniqueType { | if index.Type == core.UniqueType { | ||||
rIdxName = uniqueName(tbName, idxName) | |||||
rIdxName = uniqueName(idxPrefixName, idxName) | |||||
} else if index.Type == core.IndexType { | } else if index.Type == core.IndexType { | ||||
rIdxName = indexName(tbName, idxName) | |||||
rIdxName = indexName(idxPrefixName, idxName) | |||||
} | } | ||||
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName)) | |||||
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true))) | |||||
if statement.Engine.dialect.IndexOnTable() { | if statement.Engine.dialect.IndexOnTable() { | ||||
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName())) | |||||
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) | |||||
} | } | ||||
sqls = append(sqls, sql) | sqls = append(sqls, sql) | ||||
} | } | ||||
@@ -949,7 +922,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, | |||||
v := rValue(bean) | v := rValue(bean) | ||||
isStruct := v.Kind() == reflect.Struct | isStruct := v.Kind() == reflect.Struct | ||||
if isStruct { | if isStruct { | ||||
statement.setRefValue(v) | |||||
statement.setRefBean(bean) | |||||
} | } | ||||
var columnStr = statement.ColumnStr | var columnStr = statement.ColumnStr | ||||
@@ -982,13 +955,17 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, | |||||
if err := statement.mergeConds(bean); err != nil { | if err := statement.mergeConds(bean); err != nil { | ||||
return "", nil, err | return "", nil, err | ||||
} | } | ||||
} else { | |||||
if err := statement.processIDParam(); err != nil { | |||||
return "", nil, err | |||||
} | |||||
} | } | ||||
condSQL, condArgs, err := builder.ToSQL(statement.cond) | condSQL, condArgs, err := builder.ToSQL(statement.cond) | ||||
if err != nil { | if err != nil { | ||||
return "", nil, err | return "", nil, err | ||||
} | } | ||||
sqlStr, err := statement.genSelectSQL(columnStr, condSQL) | |||||
sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true) | |||||
if err != nil { | if err != nil { | ||||
return "", nil, err | return "", nil, err | ||||
} | } | ||||
@@ -1001,7 +978,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa | |||||
var condArgs []interface{} | var condArgs []interface{} | ||||
var err error | var err error | ||||
if len(beans) > 0 { | if len(beans) > 0 { | ||||
statement.setRefValue(rValue(beans[0])) | |||||
statement.setRefBean(beans[0]) | |||||
condSQL, condArgs, err = statement.genConds(beans[0]) | condSQL, condArgs, err = statement.genConds(beans[0]) | ||||
} else { | } else { | ||||
condSQL, condArgs, err = builder.ToSQL(statement.cond) | condSQL, condArgs, err = builder.ToSQL(statement.cond) | ||||
@@ -1018,7 +995,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa | |||||
selectSQL = "count(*)" | selectSQL = "count(*)" | ||||
} | } | ||||
} | } | ||||
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL) | |||||
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false) | |||||
if err != nil { | if err != nil { | ||||
return "", nil, err | return "", nil, err | ||||
} | } | ||||
@@ -1027,7 +1004,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa | |||||
} | } | ||||
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { | func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { | ||||
statement.setRefValue(rValue(bean)) | |||||
statement.setRefBean(bean) | |||||
var sumStrs = make([]string, 0, len(columns)) | var sumStrs = make([]string, 0, len(columns)) | ||||
for _, colName := range columns { | for _, colName := range columns { | ||||
@@ -1043,7 +1020,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri | |||||
return "", nil, err | return "", nil, err | ||||
} | } | ||||
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL) | |||||
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true) | |||||
if err != nil { | if err != nil { | ||||
return "", nil, err | return "", nil, err | ||||
} | } | ||||
@@ -1051,27 +1028,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri | |||||
return sqlStr, append(statement.joinArgs, condArgs...), nil | return sqlStr, append(statement.joinArgs, condArgs...), nil | ||||
} | } | ||||
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) { | |||||
var distinct string | |||||
func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) { | |||||
var ( | |||||
distinct string | |||||
dialect = statement.Engine.Dialect() | |||||
quote = statement.Engine.Quote | |||||
fromStr = " FROM " | |||||
top, mssqlCondi, whereStr string | |||||
) | |||||
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { | if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { | ||||
distinct = "DISTINCT " | distinct = "DISTINCT " | ||||
} | } | ||||
var dialect = statement.Engine.Dialect() | |||||
var quote = statement.Engine.Quote | |||||
var top string | |||||
var mssqlCondi string | |||||
if err := statement.processIDParam(); err != nil { | |||||
return "", err | |||||
} | |||||
var buf bytes.Buffer | |||||
if len(condSQL) > 0 { | if len(condSQL) > 0 { | ||||
fmt.Fprintf(&buf, " WHERE %v", condSQL) | |||||
whereStr = " WHERE " + condSQL | |||||
} | } | ||||
var whereStr = buf.String() | |||||
var fromStr = " FROM " | |||||
if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") { | if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") { | ||||
fromStr += statement.TableName() | fromStr += statement.TableName() | ||||
@@ -1118,9 +1088,10 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e | |||||
} | } | ||||
var orderStr string | var orderStr string | ||||
if len(statement.OrderStr) > 0 { | |||||
if needOrderBy && len(statement.OrderStr) > 0 { | |||||
orderStr = " ORDER BY " + statement.OrderStr | orderStr = " ORDER BY " + statement.OrderStr | ||||
} | } | ||||
var groupStr string | var groupStr string | ||||
if len(statement.GroupByStr) > 0 { | if len(statement.GroupByStr) > 0 { | ||||
groupStr = " GROUP BY " + statement.GroupByStr | groupStr = " GROUP BY " + statement.GroupByStr | ||||
@@ -1130,45 +1101,50 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e | |||||
} | } | ||||
} | } | ||||
// !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern | |||||
a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) | |||||
var buf builder.StringBuilder | |||||
fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) | |||||
if len(mssqlCondi) > 0 { | if len(mssqlCondi) > 0 { | ||||
if len(whereStr) > 0 { | if len(whereStr) > 0 { | ||||
a += " AND " + mssqlCondi | |||||
fmt.Fprint(&buf, " AND ", mssqlCondi) | |||||
} else { | } else { | ||||
a += " WHERE " + mssqlCondi | |||||
fmt.Fprint(&buf, " WHERE ", mssqlCondi) | |||||
} | } | ||||
} | } | ||||
if statement.GroupByStr != "" { | if statement.GroupByStr != "" { | ||||
a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr) | |||||
fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr) | |||||
} | } | ||||
if statement.HavingStr != "" { | if statement.HavingStr != "" { | ||||
a = fmt.Sprintf("%v %v", a, statement.HavingStr) | |||||
fmt.Fprint(&buf, " ", statement.HavingStr) | |||||
} | } | ||||
if statement.OrderStr != "" { | |||||
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) | |||||
if needOrderBy && statement.OrderStr != "" { | |||||
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) | |||||
} | } | ||||
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { | |||||
if statement.Start > 0 { | |||||
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) | |||||
} else if statement.LimitN > 0 { | |||||
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) | |||||
} | |||||
} else if dialect.DBType() == core.ORACLE { | |||||
if statement.Start != 0 || statement.LimitN != 0 { | |||||
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) | |||||
if needLimit { | |||||
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { | |||||
if statement.Start > 0 { | |||||
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start) | |||||
} else if statement.LimitN > 0 { | |||||
fmt.Fprint(&buf, " LIMIT ", statement.LimitN) | |||||
} | |||||
} else if dialect.DBType() == core.ORACLE { | |||||
if statement.Start != 0 || statement.LimitN != 0 { | |||||
oldString := buf.String() | |||||
buf.Reset() | |||||
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", | |||||
columnStr, columnStr, oldString, statement.Start+statement.LimitN, statement.Start) | |||||
} | |||||
} | } | ||||
} | } | ||||
if statement.IsForUpdate { | if statement.IsForUpdate { | ||||
a = dialect.ForUpdateSql(a) | |||||
return dialect.ForUpdateSql(buf.String()), nil | |||||
} | } | ||||
return | |||||
return buf.String(), nil | |||||
} | } | ||||
func (statement *Statement) processIDParam() error { | func (statement *Statement) processIDParam() error { | ||||
if statement.idParam == nil { | |||||
if statement.idParam == nil || statement.RefTable == nil { | |||||
return nil | return nil | ||||
} | } | ||||
@@ -17,7 +17,7 @@ import ( | |||||
const ( | const ( | ||||
// Version show the xorm's version | // Version show the xorm's version | ||||
Version string = "0.6.4.0910" | |||||
Version string = "0.7.0.0504" | |||||
) | ) | ||||
func regDrvsNDialects() bool { | func regDrvsNDialects() bool { | ||||
@@ -31,7 +31,7 @@ func regDrvsNDialects() bool { | |||||
"mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }}, | "mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }}, | ||||
"mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }}, | "mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }}, | ||||
"postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, | "postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, | ||||
"pgx": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, | |||||
"pgx": {"postgres", func() core.Driver { return &pqDriverPgx{} }, func() core.Dialect { return &postgres{} }}, | |||||
"sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }}, | "sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }}, | ||||
"oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }}, | "oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }}, | ||||
"goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }}, | "goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }}, | ||||
@@ -90,6 +90,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { | |||||
TagIdentifier: "xorm", | TagIdentifier: "xorm", | ||||
TZLocation: time.Local, | TZLocation: time.Local, | ||||
tagHandlers: defaultTagHandlers, | tagHandlers: defaultTagHandlers, | ||||
cachers: make(map[string]core.Cacher), | |||||
} | } | ||||
if uri.DbType == core.SQLITE { | if uri.DbType == core.SQLITE { | ||||
@@ -108,6 +109,13 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { | |||||
return engine, nil | return engine, nil | ||||
} | } | ||||
// NewEngineWithParams new a db manager with params. The params will be passed to dialect. | |||||
func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) { | |||||
engine, err := NewEngine(driverName, dataSourceName) | |||||
engine.dialect.SetParams(params) | |||||
return engine, err | |||||
} | |||||
// Clone clone an engine | // Clone clone an engine | ||||
func (engine *Engine) Clone() (*Engine, error) { | func (engine *Engine) Clone() (*Engine, error) { | ||||
return NewEngine(engine.DriverName(), engine.DataSourceName()) | return NewEngine(engine.DriverName(), engine.DataSourceName()) | ||||