You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

mongo.go 14 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. // Copyright (C) MongoDB, Inc. 2017-present.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License"); you may
  4. // not use this file except in compliance with the License. You may obtain
  5. // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
  6. package mongo // import "go.mongodb.org/mongo-driver/mongo"
  7. import (
  8. "context"
  9. "errors"
  10. "fmt"
  11. "net"
  12. "reflect"
  13. "strconv"
  14. "strings"
  15. "go.mongodb.org/mongo-driver/mongo/options"
  16. "go.mongodb.org/mongo-driver/x/bsonx"
  17. "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
  18. "go.mongodb.org/mongo-driver/bson"
  19. "go.mongodb.org/mongo-driver/bson/bsoncodec"
  20. "go.mongodb.org/mongo-driver/bson/bsontype"
  21. "go.mongodb.org/mongo-driver/bson/primitive"
  22. )
  23. // Dialer is used to make network connections.
  24. type Dialer interface {
  25. DialContext(ctx context.Context, network, address string) (net.Conn, error)
  26. }
  27. // BSONAppender is an interface implemented by types that can marshal a
  28. // provided type into BSON bytes and append those bytes to the provided []byte.
  29. // The AppendBSON can return a non-nil error and non-nil []byte. The AppendBSON
  30. // method may also write incomplete BSON to the []byte.
  31. type BSONAppender interface {
  32. AppendBSON([]byte, interface{}) ([]byte, error)
  33. }
  34. // BSONAppenderFunc is an adapter function that allows any function that
  35. // satisfies the AppendBSON method signature to be used where a BSONAppender is
  36. // used.
  37. type BSONAppenderFunc func([]byte, interface{}) ([]byte, error)
  38. // AppendBSON implements the BSONAppender interface
  39. func (baf BSONAppenderFunc) AppendBSON(dst []byte, val interface{}) ([]byte, error) {
  40. return baf(dst, val)
  41. }
  42. // MarshalError is returned when attempting to transform a value into a document
  43. // results in an error.
  44. type MarshalError struct {
  45. Value interface{}
  46. Err error
  47. }
  48. // Error implements the error interface.
  49. func (me MarshalError) Error() string {
  50. return fmt.Sprintf("cannot transform type %s to a BSON Document: %v", reflect.TypeOf(me.Value), me.Err)
  51. }
  52. // Pipeline is a type that makes creating aggregation pipelines easier. It is a
  53. // helper and is intended for serializing to BSON.
  54. //
  55. // Example usage:
  56. //
  57. // mongo.Pipeline{
  58. // {{"$group", bson.D{{"_id", "$state"}, {"totalPop", bson.D{{"$sum", "$pop"}}}}}},
  59. // {{"$match", bson.D{{"totalPop", bson.D{{"$gte", 10*1000*1000}}}}}},
  60. // }
  61. //
  62. type Pipeline []bson.D
  63. // transformAndEnsureID is a hack that makes it easy to get a RawValue as the _id value. This will
  64. // be removed when we switch from using bsonx to bsoncore for the driver package.
  65. func transformAndEnsureID(registry *bsoncodec.Registry, val interface{}) (bsonx.Doc, interface{}, error) {
  66. // TODO: performance is going to be pretty bad for bsonx.Doc here since we turn it into a []byte
  67. // only to turn it back into a bsonx.Doc. We can fix this post beta1 when we refactor the driver
  68. // package to use bsoncore.Document instead of bsonx.Doc.
  69. if registry == nil {
  70. registry = bson.NewRegistryBuilder().Build()
  71. }
  72. switch tt := val.(type) {
  73. case nil:
  74. return nil, nil, ErrNilDocument
  75. case bsonx.Doc:
  76. val = tt.Copy()
  77. case []byte:
  78. // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery.
  79. val = bson.Raw(tt)
  80. }
  81. // TODO(skriptble): Use a pool of these instead.
  82. buf := make([]byte, 0, 256)
  83. b, err := bson.MarshalAppendWithRegistry(registry, buf, val)
  84. if err != nil {
  85. return nil, nil, MarshalError{Value: val, Err: err}
  86. }
  87. d, err := bsonx.ReadDoc(b)
  88. if err != nil {
  89. return nil, nil, err
  90. }
  91. var id interface{}
  92. idx := d.IndexOf("_id")
  93. var idElem bsonx.Elem
  94. switch idx {
  95. case -1:
  96. idElem = bsonx.Elem{"_id", bsonx.ObjectID(primitive.NewObjectID())}
  97. d = append(d, bsonx.Elem{})
  98. copy(d[1:], d)
  99. d[0] = idElem
  100. default:
  101. idElem = d[idx]
  102. copy(d[1:idx+1], d[0:idx])
  103. d[0] = idElem
  104. }
  105. idBuf := make([]byte, 0, 256)
  106. t, data, err := idElem.Value.MarshalAppendBSONValue(idBuf[:0])
  107. if err != nil {
  108. return nil, nil, err
  109. }
  110. err = bson.RawValue{Type: t, Value: data}.UnmarshalWithRegistry(registry, &id)
  111. if err != nil {
  112. return nil, nil, err
  113. }
  114. return d, id, nil
  115. }
  116. // transformAndEnsureIDv2 is a hack that makes it easy to get a RawValue as the _id value. This will
  117. // be removed when we switch from using bsonx to bsoncore for the driver package.
  118. func transformAndEnsureIDv2(registry *bsoncodec.Registry, val interface{}) (bsoncore.Document, interface{}, error) {
  119. if registry == nil {
  120. registry = bson.NewRegistryBuilder().Build()
  121. }
  122. switch tt := val.(type) {
  123. case nil:
  124. return nil, nil, ErrNilDocument
  125. case bsonx.Doc:
  126. val = tt.Copy()
  127. case []byte:
  128. // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery.
  129. val = bson.Raw(tt)
  130. }
  131. // TODO(skriptble): Use a pool of these instead.
  132. doc := make(bsoncore.Document, 0, 256)
  133. doc, err := bson.MarshalAppendWithRegistry(registry, doc, val)
  134. if err != nil {
  135. return nil, nil, MarshalError{Value: val, Err: err}
  136. }
  137. var id interface{}
  138. value := doc.Lookup("_id")
  139. switch value.Type {
  140. case bsontype.Type(0):
  141. value = bsoncore.Value{Type: bsontype.ObjectID, Data: bsoncore.AppendObjectID(nil, primitive.NewObjectID())}
  142. olddoc := doc
  143. doc = make(bsoncore.Document, 0, len(olddoc)+17) // type byte + _id + null byte + object ID
  144. _, doc = bsoncore.ReserveLength(doc)
  145. doc = bsoncore.AppendValueElement(doc, "_id", value)
  146. doc = append(doc, olddoc[4:]...) // remove the length
  147. doc = bsoncore.UpdateLength(doc, 0, int32(len(doc)))
  148. default:
  149. // We copy the bytes here to ensure that any bytes returned to the user aren't modified
  150. // later.
  151. buf := make([]byte, len(value.Data))
  152. copy(buf, value.Data)
  153. value.Data = buf
  154. }
  155. err = bson.RawValue{Type: value.Type, Value: value.Data}.UnmarshalWithRegistry(registry, &id)
  156. if err != nil {
  157. return nil, nil, err
  158. }
  159. return doc, id, nil
  160. }
  161. func transformDocument(registry *bsoncodec.Registry, val interface{}) (bsonx.Doc, error) {
  162. if doc, ok := val.(bsonx.Doc); ok {
  163. return doc.Copy(), nil
  164. }
  165. b, err := transformBsoncoreDocument(registry, val)
  166. if err != nil {
  167. return nil, err
  168. }
  169. return bsonx.ReadDoc(b)
  170. }
  171. func transformBsoncoreDocument(registry *bsoncodec.Registry, val interface{}) (bsoncore.Document, error) {
  172. if registry == nil {
  173. registry = bson.DefaultRegistry
  174. }
  175. if val == nil {
  176. return nil, ErrNilDocument
  177. }
  178. if bs, ok := val.([]byte); ok {
  179. // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery.
  180. val = bson.Raw(bs)
  181. }
  182. // TODO(skriptble): Use a pool of these instead.
  183. buf := make([]byte, 0, 256)
  184. b, err := bson.MarshalAppendWithRegistry(registry, buf[:0], val)
  185. if err != nil {
  186. return nil, MarshalError{Value: val, Err: err}
  187. }
  188. return b, nil
  189. }
  190. func ensureID(d bsonx.Doc) (bsonx.Doc, interface{}) {
  191. var id interface{}
  192. elem, err := d.LookupElementErr("_id")
  193. switch err.(type) {
  194. case nil:
  195. id = elem
  196. default:
  197. oid := primitive.NewObjectID()
  198. d = append(d, bsonx.Elem{"_id", bsonx.ObjectID(oid)})
  199. id = oid
  200. }
  201. return d, id
  202. }
  203. func ensureDollarKey(doc bsonx.Doc) error {
  204. if len(doc) == 0 {
  205. return errors.New("update document must have at least one element")
  206. }
  207. if !strings.HasPrefix(doc[0].Key, "$") {
  208. return errors.New("update document must contain key beginning with '$'")
  209. }
  210. return nil
  211. }
  212. func ensureDollarKeyv2(doc bsoncore.Document) error {
  213. firstElem, err := doc.IndexErr(0)
  214. if err != nil {
  215. return errors.New("update document must have at least one element")
  216. }
  217. if !strings.HasPrefix(firstElem.Key(), "$") {
  218. return errors.New("update document must contain key beginning with '$'")
  219. }
  220. return nil
  221. }
  222. func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface{}) (bsonx.Arr, error) {
  223. pipelineArr := bsonx.Arr{}
  224. switch t := pipeline.(type) {
  225. case bsoncodec.ValueMarshaler:
  226. btype, val, err := t.MarshalBSONValue()
  227. if err != nil {
  228. return nil, err
  229. }
  230. if btype != bsontype.Array {
  231. return nil, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", btype, bsontype.Array)
  232. }
  233. err = pipelineArr.UnmarshalBSONValue(btype, val)
  234. if err != nil {
  235. return nil, err
  236. }
  237. default:
  238. val := reflect.ValueOf(t)
  239. if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) {
  240. return nil, fmt.Errorf("can only transform slices and arrays into aggregation pipelines, but got %v", val.Kind())
  241. }
  242. for idx := 0; idx < val.Len(); idx++ {
  243. elem, err := transformDocument(registry, val.Index(idx).Interface())
  244. if err != nil {
  245. return nil, err
  246. }
  247. pipelineArr = append(pipelineArr, bsonx.Document(elem))
  248. }
  249. }
  250. return pipelineArr, nil
  251. }
  252. func transformAggregatePipelinev2(registry *bsoncodec.Registry, pipeline interface{}) (bsoncore.Document, bool, error) {
  253. switch t := pipeline.(type) {
  254. case bsoncodec.ValueMarshaler:
  255. btype, val, err := t.MarshalBSONValue()
  256. if err != nil {
  257. return nil, false, err
  258. }
  259. if btype != bsontype.Array {
  260. return nil, false, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", btype, bsontype.Array)
  261. }
  262. var hasOutputStage bool
  263. pipelineDoc := bsoncore.Document(val)
  264. if _, err := pipelineDoc.LookupErr("$out"); err == nil {
  265. hasOutputStage = true
  266. }
  267. if _, err := pipelineDoc.LookupErr("$merge"); err == nil {
  268. hasOutputStage = true
  269. }
  270. return pipelineDoc, hasOutputStage, nil
  271. default:
  272. val := reflect.ValueOf(t)
  273. if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) {
  274. return nil, false, fmt.Errorf("can only transform slices and arrays into aggregation pipelines, but got %v", val.Kind())
  275. }
  276. aidx, arr := bsoncore.AppendArrayStart(nil)
  277. var hasOutputStage bool
  278. valLen := val.Len()
  279. for idx := 0; idx < valLen; idx++ {
  280. doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface())
  281. if err != nil {
  282. return nil, false, err
  283. }
  284. if idx == valLen-1 {
  285. if elem, err := doc.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") {
  286. hasOutputStage = true
  287. }
  288. }
  289. arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(idx), doc)
  290. }
  291. arr, _ = bsoncore.AppendArrayEnd(arr, aidx)
  292. return arr, hasOutputStage, nil
  293. }
  294. }
  295. func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, checkDocDollarKey bool) (bsoncore.Value, error) {
  296. var u bsoncore.Value
  297. var err error
  298. switch t := update.(type) {
  299. case nil:
  300. return u, ErrNilDocument
  301. case primitive.D, bsonx.Doc:
  302. u.Type = bsontype.EmbeddedDocument
  303. u.Data, err = transformBsoncoreDocument(registry, update)
  304. if err != nil {
  305. return u, err
  306. }
  307. if checkDocDollarKey {
  308. err = ensureDollarKeyv2(u.Data)
  309. }
  310. return u, err
  311. case bson.Raw:
  312. u.Type = bsontype.EmbeddedDocument
  313. u.Data = t
  314. if checkDocDollarKey {
  315. err = ensureDollarKeyv2(u.Data)
  316. }
  317. return u, err
  318. case bsoncore.Document:
  319. u.Type = bsontype.EmbeddedDocument
  320. u.Data = t
  321. if checkDocDollarKey {
  322. err = ensureDollarKeyv2(u.Data)
  323. }
  324. return u, err
  325. case []byte:
  326. u.Type = bsontype.EmbeddedDocument
  327. u.Data = t
  328. if checkDocDollarKey {
  329. err = ensureDollarKeyv2(u.Data)
  330. }
  331. return u, err
  332. case bsoncodec.Marshaler:
  333. u.Type = bsontype.EmbeddedDocument
  334. u.Data, err = t.MarshalBSON()
  335. if err != nil {
  336. return u, err
  337. }
  338. if checkDocDollarKey {
  339. err = ensureDollarKeyv2(u.Data)
  340. }
  341. return u, err
  342. case bsoncodec.ValueMarshaler:
  343. u.Type, u.Data, err = t.MarshalBSONValue()
  344. if err != nil {
  345. return u, err
  346. }
  347. if u.Type != bsontype.Array && u.Type != bsontype.EmbeddedDocument {
  348. return u, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v or %v", u.Type, bsontype.Array, bsontype.EmbeddedDocument)
  349. }
  350. return u, err
  351. default:
  352. val := reflect.ValueOf(t)
  353. if !val.IsValid() {
  354. return u, fmt.Errorf("can only transform slices and arrays into update pipelines, but got %v", val.Kind())
  355. }
  356. if val.Kind() != reflect.Slice && val.Kind() != reflect.Array {
  357. u.Type = bsontype.EmbeddedDocument
  358. u.Data, err = transformBsoncoreDocument(registry, update)
  359. if err != nil {
  360. return u, err
  361. }
  362. if checkDocDollarKey {
  363. err = ensureDollarKeyv2(u.Data)
  364. }
  365. return u, err
  366. }
  367. u.Type = bsontype.Array
  368. aidx, arr := bsoncore.AppendArrayStart(nil)
  369. valLen := val.Len()
  370. for idx := 0; idx < valLen; idx++ {
  371. doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface())
  372. if err != nil {
  373. return u, err
  374. }
  375. if err := ensureDollarKeyv2(doc); err != nil {
  376. return u, err
  377. }
  378. arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(idx), doc)
  379. }
  380. u.Data, _ = bsoncore.AppendArrayEnd(arr, aidx)
  381. return u, err
  382. }
  383. }
  384. func transformValue(registry *bsoncodec.Registry, val interface{}) (bsoncore.Value, error) {
  385. switch conv := val.(type) {
  386. case string:
  387. return bsoncore.Value{Type: bsontype.String, Data: bsoncore.AppendString(nil, conv)}, nil
  388. default:
  389. doc, err := transformBsoncoreDocument(registry, val)
  390. if err != nil {
  391. return bsoncore.Value{}, err
  392. }
  393. return bsoncore.Value{Type: bsontype.EmbeddedDocument, Data: doc}, nil
  394. }
  395. }
  396. // Build the aggregation pipeline for the CountDocument command.
  397. func countDocumentsAggregatePipeline(registry *bsoncodec.Registry, filter interface{}, opts *options.CountOptions) (bsoncore.Document, error) {
  398. filterDoc, err := transformBsoncoreDocument(registry, filter)
  399. if err != nil {
  400. return nil, err
  401. }
  402. aidx, arr := bsoncore.AppendArrayStart(nil)
  403. didx, arr := bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(0))
  404. arr = bsoncore.AppendDocumentElement(arr, "$match", filterDoc)
  405. arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
  406. index := 1
  407. if opts != nil {
  408. if opts.Skip != nil {
  409. didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index))
  410. arr = bsoncore.AppendInt64Element(arr, "$skip", *opts.Skip)
  411. arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
  412. index++
  413. }
  414. if opts.Limit != nil {
  415. didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index))
  416. arr = bsoncore.AppendInt64Element(arr, "$limit", *opts.Limit)
  417. arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
  418. index++
  419. }
  420. }
  421. didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index))
  422. iidx, arr := bsoncore.AppendDocumentElementStart(arr, "$group")
  423. arr = bsoncore.AppendInt32Element(arr, "_id", 1)
  424. iiidx, arr := bsoncore.AppendDocumentElementStart(arr, "n")
  425. arr = bsoncore.AppendInt32Element(arr, "$sum", 1)
  426. arr, _ = bsoncore.AppendDocumentEnd(arr, iiidx)
  427. arr, _ = bsoncore.AppendDocumentEnd(arr, iidx)
  428. arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
  429. return bsoncore.AppendArrayEnd(arr, aidx)
  430. }