You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

session_insert.go 18 kB


  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "errors"
  7. "fmt"
  8. "reflect"
  9. "sort"
  10. "strconv"
  11. "strings"
  12. "xorm.io/builder"
  13. "xorm.io/xorm/internal/utils"
  14. "xorm.io/xorm/schemas"
  15. )
  // ErrNoElementsOnSlice is returned by Insert when one of its arguments is an
// empty slice (or pointer to an empty slice) of beans.
var ErrNoElementsOnSlice error = errors.New("No element on slice when insert")
  18. // Insert insert one or more beans
  19. func (session *Session) Insert(beans ...interface{}) (int64, error) {
  20. var affected int64
  21. var err error
  22. if session.isAutoClose {
  23. defer session.Close()
  24. }
  25. session.autoResetStatement = false
  26. defer func() {
  27. session.autoResetStatement = true
  28. session.resetStatement()
  29. }()
  30. for _, bean := range beans {
  31. switch bean.(type) {
  32. case map[string]interface{}:
  33. cnt, err := session.insertMapInterface(bean.(map[string]interface{}))
  34. if err != nil {
  35. return affected, err
  36. }
  37. affected += cnt
  38. case []map[string]interface{}:
  39. s := bean.([]map[string]interface{})
  40. for i := 0; i < len(s); i++ {
  41. cnt, err := session.insertMapInterface(s[i])
  42. if err != nil {
  43. return affected, err
  44. }
  45. affected += cnt
  46. }
  47. case map[string]string:
  48. cnt, err := session.insertMapString(bean.(map[string]string))
  49. if err != nil {
  50. return affected, err
  51. }
  52. affected += cnt
  53. case []map[string]string:
  54. s := bean.([]map[string]string)
  55. for i := 0; i < len(s); i++ {
  56. cnt, err := session.insertMapString(s[i])
  57. if err != nil {
  58. return affected, err
  59. }
  60. affected += cnt
  61. }
  62. default:
  63. sliceValue := reflect.Indirect(reflect.ValueOf(bean))
  64. if sliceValue.Kind() == reflect.Slice {
  65. size := sliceValue.Len()
  66. if size <= 0 {
  67. return 0, ErrNoElementsOnSlice
  68. }
  69. cnt, err := session.innerInsertMulti(bean)
  70. if err != nil {
  71. return affected, err
  72. }
  73. affected += cnt
  74. } else {
  75. cnt, err := session.innerInsert(bean)
  76. if err != nil {
  77. return affected, err
  78. }
  79. affected += cnt
  80. }
  81. }
  82. }
  83. return affected, err
  84. }
  85. func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) {
  86. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  87. if sliceValue.Kind() != reflect.Slice {
  88. return 0, errors.New("needs a pointer to a slice")
  89. }
  90. if sliceValue.Len() <= 0 {
  91. return 0, errors.New("could not insert a empty slice")
  92. }
  93. if err := session.statement.SetRefBean(sliceValue.Index(0).Interface()); err != nil {
  94. return 0, err
  95. }
  96. tableName := session.statement.TableName()
  97. if len(tableName) <= 0 {
  98. return 0, ErrTableNotFound
  99. }
  100. table := session.statement.RefTable
  101. size := sliceValue.Len()
  102. var colNames []string
  103. var colMultiPlaces []string
  104. var args []interface{}
  105. var cols []*schemas.Column
  106. for i := 0; i < size; i++ {
  107. v := sliceValue.Index(i)
  108. var vv reflect.Value
  109. switch v.Kind() {
  110. case reflect.Interface:
  111. vv = reflect.Indirect(v.Elem())
  112. default:
  113. vv = reflect.Indirect(v)
  114. }
  115. elemValue := v.Interface()
  116. var colPlaces []string
  117. // handle BeforeInsertProcessor
  118. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  119. for _, closure := range session.beforeClosures {
  120. closure(elemValue)
  121. }
  122. if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok {
  123. processor.BeforeInsert()
  124. }
  125. // --
  126. for _, col := range table.Columns() {
  127. ptrFieldValue, err := col.ValueOfV(&vv)
  128. if err != nil {
  129. return 0, err
  130. }
  131. fieldValue := *ptrFieldValue
  132. if col.IsAutoIncrement && utils.IsZero(fieldValue.Interface()) {
  133. continue
  134. }
  135. if col.MapType == schemas.ONLYFROMDB {
  136. continue
  137. }
  138. if col.IsDeleted {
  139. continue
  140. }
  141. if session.statement.OmitColumnMap.Contain(col.Name) {
  142. continue
  143. }
  144. if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
  145. continue
  146. }
  147. if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
  148. val, t := session.engine.nowTime(col)
  149. args = append(args, val)
  150. var colName = col.Name
  151. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  152. col := table.GetColumn(colName)
  153. setColumnTime(bean, col, t)
  154. })
  155. } else if col.IsVersion && session.statement.CheckVersion {
  156. args = append(args, 1)
  157. var colName = col.Name
  158. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  159. col := table.GetColumn(colName)
  160. setColumnInt(bean, col, 1)
  161. })
  162. } else {
  163. arg, err := session.statement.Value2Interface(col, fieldValue)
  164. if err != nil {
  165. return 0, err
  166. }
  167. args = append(args, arg)
  168. }
  169. if i == 0 {
  170. colNames = append(colNames, col.Name)
  171. cols = append(cols, col)
  172. }
  173. colPlaces = append(colPlaces, "?")
  174. }
  175. colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
  176. }
  177. cleanupProcessorsClosures(&session.beforeClosures)
  178. quoter := session.engine.dialect.Quoter()
  179. var sql string
  180. colStr := quoter.Join(colNames, ",")
  181. if session.engine.dialect.URI().DBType == schemas.ORACLE {
  182. temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
  183. quoter.Quote(tableName),
  184. colStr)
  185. sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
  186. quoter.Quote(tableName),
  187. colStr,
  188. strings.Join(colMultiPlaces, temp))
  189. } else {
  190. sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
  191. quoter.Quote(tableName),
  192. colStr,
  193. strings.Join(colMultiPlaces, "),("))
  194. }
  195. res, err := session.exec(sql, args...)
  196. if err != nil {
  197. return 0, err
  198. }
  199. session.cacheInsert(tableName)
  200. lenAfterClosures := len(session.afterClosures)
  201. for i := 0; i < size; i++ {
  202. elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()
  203. // handle AfterInsertProcessor
  204. if session.isAutoCommit {
  205. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  206. for _, closure := range session.afterClosures {
  207. closure(elemValue)
  208. }
  209. if processor, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
  210. processor.AfterInsert()
  211. }
  212. } else {
  213. if lenAfterClosures > 0 {
  214. if value, has := session.afterInsertBeans[elemValue]; has && value != nil {
  215. *value = append(*value, session.afterClosures...)
  216. } else {
  217. afterClosures := make([]func(interface{}), lenAfterClosures)
  218. copy(afterClosures, session.afterClosures)
  219. session.afterInsertBeans[elemValue] = &afterClosures
  220. }
  221. } else {
  222. if _, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
  223. session.afterInsertBeans[elemValue] = nil
  224. }
  225. }
  226. }
  227. }
  228. cleanupProcessorsClosures(&session.afterClosures)
  229. return res.RowsAffected()
  230. }
  231. // InsertMulti insert multiple records
  232. func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
  233. if session.isAutoClose {
  234. defer session.Close()
  235. }
  236. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  237. if sliceValue.Kind() != reflect.Slice {
  238. return 0, ErrParamsType
  239. }
  240. if sliceValue.Len() <= 0 {
  241. return 0, nil
  242. }
  243. return session.innerInsertMulti(rowsSlicePtr)
  244. }
  245. func (session *Session) innerInsert(bean interface{}) (int64, error) {
  246. if err := session.statement.SetRefBean(bean); err != nil {
  247. return 0, err
  248. }
  249. if len(session.statement.TableName()) <= 0 {
  250. return 0, ErrTableNotFound
  251. }
  252. // handle BeforeInsertProcessor
  253. for _, closure := range session.beforeClosures {
  254. closure(bean)
  255. }
  256. cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
  257. if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
  258. processor.BeforeInsert()
  259. }
  260. var tableName = session.statement.TableName()
  261. table := session.statement.RefTable
  262. colNames, args, err := session.genInsertColumns(bean)
  263. if err != nil {
  264. return 0, err
  265. }
  266. sqlStr, args, err := session.statement.GenInsertSQL(colNames, args)
  267. if err != nil {
  268. return 0, err
  269. }
  270. handleAfterInsertProcessorFunc := func(bean interface{}) {
  271. if session.isAutoCommit {
  272. for _, closure := range session.afterClosures {
  273. closure(bean)
  274. }
  275. if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
  276. processor.AfterInsert()
  277. }
  278. } else {
  279. lenAfterClosures := len(session.afterClosures)
  280. if lenAfterClosures > 0 {
  281. if value, has := session.afterInsertBeans[bean]; has && value != nil {
  282. *value = append(*value, session.afterClosures...)
  283. } else {
  284. afterClosures := make([]func(interface{}), lenAfterClosures)
  285. copy(afterClosures, session.afterClosures)
  286. session.afterInsertBeans[bean] = &afterClosures
  287. }
  288. } else {
  289. if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
  290. session.afterInsertBeans[bean] = nil
  291. }
  292. }
  293. }
  294. cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
  295. }
  296. // for postgres, many of them didn't implement lastInsertId, so we should
  297. // implemented it ourself.
  298. if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 {
  299. res, err := session.queryBytes("select seq_atable.currval from dual", args...)
  300. if err != nil {
  301. return 0, err
  302. }
  303. defer handleAfterInsertProcessorFunc(bean)
  304. session.cacheInsert(tableName)
  305. if table.Version != "" && session.statement.CheckVersion {
  306. verValue, err := table.VersionColumn().ValueOf(bean)
  307. if err != nil {
  308. session.engine.logger.Errorf("%v", err)
  309. } else if verValue.IsValid() && verValue.CanSet() {
  310. session.incrVersionFieldValue(verValue)
  311. }
  312. }
  313. if len(res) < 1 {
  314. return 0, errors.New("insert no error but not returned id")
  315. }
  316. idByte := res[0][table.AutoIncrement]
  317. id, err := strconv.ParseInt(string(idByte), 10, 64)
  318. if err != nil || id <= 0 {
  319. return 1, err
  320. }
  321. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  322. if err != nil {
  323. session.engine.logger.Errorf("%v", err)
  324. }
  325. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  326. return 1, nil
  327. }
  328. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  329. return 1, nil
  330. } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES ||
  331. session.engine.dialect.URI().DBType == schemas.MSSQL) {
  332. res, err := session.queryBytes(sqlStr, args...)
  333. if err != nil {
  334. return 0, err
  335. }
  336. defer handleAfterInsertProcessorFunc(bean)
  337. session.cacheInsert(tableName)
  338. if table.Version != "" && session.statement.CheckVersion {
  339. verValue, err := table.VersionColumn().ValueOf(bean)
  340. if err != nil {
  341. session.engine.logger.Errorf("%v", err)
  342. } else if verValue.IsValid() && verValue.CanSet() {
  343. session.incrVersionFieldValue(verValue)
  344. }
  345. }
  346. if len(res) < 1 {
  347. return 0, errors.New("insert successfully but not returned id")
  348. }
  349. idByte := res[0][table.AutoIncrement]
  350. id, err := strconv.ParseInt(string(idByte), 10, 64)
  351. if err != nil || id <= 0 {
  352. return 1, err
  353. }
  354. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  355. if err != nil {
  356. session.engine.logger.Errorf("%v", err)
  357. }
  358. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  359. return 1, nil
  360. }
  361. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  362. return 1, nil
  363. }
  364. res, err := session.exec(sqlStr, args...)
  365. if err != nil {
  366. return 0, err
  367. }
  368. defer handleAfterInsertProcessorFunc(bean)
  369. session.cacheInsert(tableName)
  370. if table.Version != "" && session.statement.CheckVersion {
  371. verValue, err := table.VersionColumn().ValueOf(bean)
  372. if err != nil {
  373. session.engine.logger.Errorf("%v", err)
  374. } else if verValue.IsValid() && verValue.CanSet() {
  375. session.incrVersionFieldValue(verValue)
  376. }
  377. }
  378. if table.AutoIncrement == "" {
  379. return res.RowsAffected()
  380. }
  381. var id int64
  382. id, err = res.LastInsertId()
  383. if err != nil || id <= 0 {
  384. return res.RowsAffected()
  385. }
  386. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  387. if err != nil {
  388. session.engine.logger.Errorf("%v", err)
  389. }
  390. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  391. return res.RowsAffected()
  392. }
  393. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  394. return res.RowsAffected()
  395. }
  396. // InsertOne insert only one struct into database as a record.
  397. // The in parameter bean must a struct or a point to struct. The return
  398. // parameter is inserted and error
  399. func (session *Session) InsertOne(bean interface{}) (int64, error) {
  400. if session.isAutoClose {
  401. defer session.Close()
  402. }
  403. return session.innerInsert(bean)
  404. }
  405. func (session *Session) cacheInsert(table string) error {
  406. if !session.statement.UseCache {
  407. return nil
  408. }
  409. cacher := session.engine.cacherMgr.GetCacher(table)
  410. if cacher == nil {
  411. return nil
  412. }
  413. session.engine.logger.Debugf("[cache] clear sql: %v", table)
  414. cacher.ClearIds(table)
  415. return nil
  416. }
  417. // genInsertColumns generates insert needed columns
  418. func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) {
  419. table := session.statement.RefTable
  420. colNames := make([]string, 0, len(table.ColumnsSeq()))
  421. args := make([]interface{}, 0, len(table.ColumnsSeq()))
  422. for _, col := range table.Columns() {
  423. if col.MapType == schemas.ONLYFROMDB {
  424. continue
  425. }
  426. if col.IsDeleted {
  427. continue
  428. }
  429. if session.statement.OmitColumnMap.Contain(col.Name) {
  430. continue
  431. }
  432. if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
  433. continue
  434. }
  435. if session.statement.IncrColumns.IsColExist(col.Name) {
  436. continue
  437. } else if session.statement.DecrColumns.IsColExist(col.Name) {
  438. continue
  439. } else if session.statement.ExprColumns.IsColExist(col.Name) {
  440. continue
  441. }
  442. fieldValuePtr, err := col.ValueOf(bean)
  443. if err != nil {
  444. return nil, nil, err
  445. }
  446. fieldValue := *fieldValuePtr
  447. if col.IsAutoIncrement && utils.IsValueZero(fieldValue) {
  448. continue
  449. }
  450. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  451. if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok {
  452. if col.Nullable && utils.IsValueZero(fieldValue) {
  453. var nilValue *int
  454. fieldValue = reflect.ValueOf(nilValue)
  455. }
  456. }
  457. if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
  458. // if time is non-empty, then set to auto time
  459. val, t := session.engine.nowTime(col)
  460. args = append(args, val)
  461. var colName = col.Name
  462. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  463. col := table.GetColumn(colName)
  464. setColumnTime(bean, col, t)
  465. })
  466. } else if col.IsVersion && session.statement.CheckVersion {
  467. args = append(args, 1)
  468. } else {
  469. arg, err := session.statement.Value2Interface(col, fieldValue)
  470. if err != nil {
  471. return colNames, args, err
  472. }
  473. args = append(args, arg)
  474. }
  475. colNames = append(colNames, col.Name)
  476. }
  477. return colNames, args, nil
  478. }
  479. func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) {
  480. if len(m) == 0 {
  481. return 0, ErrParamsType
  482. }
  483. tableName := session.statement.TableName()
  484. if len(tableName) <= 0 {
  485. return 0, ErrTableNotFound
  486. }
  487. var columns = make([]string, 0, len(m))
  488. exprs := session.statement.ExprColumns
  489. for k := range m {
  490. if !exprs.IsColExist(k) {
  491. columns = append(columns, k)
  492. }
  493. }
  494. sort.Strings(columns)
  495. var args = make([]interface{}, 0, len(m))
  496. for _, colName := range columns {
  497. args = append(args, m[colName])
  498. }
  499. return session.insertMap(columns, args)
  500. }
  501. func (session *Session) insertMapString(m map[string]string) (int64, error) {
  502. if len(m) == 0 {
  503. return 0, ErrParamsType
  504. }
  505. tableName := session.statement.TableName()
  506. if len(tableName) <= 0 {
  507. return 0, ErrTableNotFound
  508. }
  509. var columns = make([]string, 0, len(m))
  510. exprs := session.statement.ExprColumns
  511. for k := range m {
  512. if !exprs.IsColExist(k) {
  513. columns = append(columns, k)
  514. }
  515. }
  516. sort.Strings(columns)
  517. var args = make([]interface{}, 0, len(m))
  518. for _, colName := range columns {
  519. args = append(args, m[colName])
  520. }
  521. return session.insertMap(columns, args)
  522. }
  523. func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) {
  524. tableName := session.statement.TableName()
  525. if len(tableName) <= 0 {
  526. return 0, ErrTableNotFound
  527. }
  528. exprs := session.statement.ExprColumns
  529. w := builder.NewWriter()
  530. // if insert where
  531. if session.statement.Conds().IsValid() {
  532. if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
  533. return 0, err
  534. }
  535. if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil {
  536. return 0, err
  537. }
  538. if _, err := w.WriteString(") SELECT "); err != nil {
  539. return 0, err
  540. }
  541. if err := session.statement.WriteArgs(w, args); err != nil {
  542. return 0, err
  543. }
  544. if len(exprs.Args) > 0 {
  545. if _, err := w.WriteString(","); err != nil {
  546. return 0, err
  547. }
  548. if err := exprs.WriteArgs(w); err != nil {
  549. return 0, err
  550. }
  551. }
  552. if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil {
  553. return 0, err
  554. }
  555. if err := session.statement.Conds().WriteTo(w); err != nil {
  556. return 0, err
  557. }
  558. } else {
  559. qm := strings.Repeat("?,", len(columns))
  560. qm = qm[:len(qm)-1]
  561. if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
  562. return 0, err
  563. }
  564. if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil {
  565. return 0, err
  566. }
  567. if _, err := w.WriteString(fmt.Sprintf(") VALUES (%s", qm)); err != nil {
  568. return 0, err
  569. }
  570. w.Append(args...)
  571. if len(exprs.Args) > 0 {
  572. if _, err := w.WriteString(","); err != nil {
  573. return 0, err
  574. }
  575. if err := exprs.WriteArgs(w); err != nil {
  576. return 0, err
  577. }
  578. }
  579. if _, err := w.WriteString(")"); err != nil {
  580. return 0, err
  581. }
  582. }
  583. sql := w.String()
  584. args = w.Args()
  585. if err := session.cacheInsert(tableName); err != nil {
  586. return 0, err
  587. }
  588. res, err := session.exec(sql, args...)
  589. if err != nil {
  590. return 0, err
  591. }
  592. affected, err := res.RowsAffected()
  593. if err != nil {
  594. return 0, err
  595. }
  596. return affected, nil
  597. }