// Copyright (C) MongoDB, Inc. 2017-present. // // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 package mongo import ( "context" "crypto/tls" "strconv" "strings" "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" "go.mongodb.org/mongo-driver/x/mongo/driver/uuid" ) const defaultLocalThreshold = 15 * time.Millisecond const batchSize = 10000 // Client performs operations on a given topology. type Client struct { id uuid.UUID topologyOptions []topology.Option topology *topology.Topology connString connstring.ConnString localThreshold time.Duration retryWrites bool retryReads bool clock *session.ClusterClock readPreference *readpref.ReadPref readConcern *readconcern.ReadConcern writeConcern *writeconcern.WriteConcern registry *bsoncodec.Registry marshaller BSONAppender monitor *event.CommandMonitor } // Connect creates a new Client and then initializes it using the Connect method. func Connect(ctx context.Context, opts ...*options.ClientOptions) (*Client, error) { c, err := NewClient(opts...) if err != nil { return nil, err } err = c.Connect(ctx) if err != nil { return nil, err } return c, nil } // NewClient creates a new client to connect to a cluster specified by the uri. // // When creating an options.ClientOptions, the order the methods are called matters. Later Set* // methods will overwrite the values from previous Set* method invocations. This includes the // ApplyURI method. This allows callers to determine the order of precedence for option // application. For instance, if ApplyURI is called before SetAuth, the Credential from // SetAuth will overwrite the values from the connection string. If ApplyURI is called // after SetAuth, then its values will overwrite those from SetAuth. // // The opts parameter is processed using options.MergeClientOptions, which will overwrite entire // option fields of previous options, there is no partial overwriting. For example, if Username is // set in the Auth field for the first option, and Password is set for the second but with no // Username, after the merge the Username field will be empty. func NewClient(opts ...*options.ClientOptions) (*Client, error) { clientOpt := options.MergeClientOptions(opts...) id, err := uuid.New() if err != nil { return nil, err } client := &Client{id: id} err = client.configure(clientOpt) if err != nil { return nil, err } client.topology, err = topology.New(client.topologyOptions...) if err != nil { return nil, replaceErrors(err) } return client, nil } // Connect initializes the Client by starting background monitoring goroutines. // This method must be called before a Client can be used. func (c *Client) Connect(ctx context.Context) error { err := c.topology.Connect() if err != nil { return replaceErrors(err) } return nil } // Disconnect closes sockets to the topology referenced by this Client. It will // shut down any monitoring goroutines, close the idle connection pool, and will // wait until all the in use connections have been returned to the connection // pool and closed before returning. If the context expires via cancellation, // deadline, or timeout before the in use connections have returned, the in use // connections will be closed, resulting in the failure of any in flight read // or write operations. If this method returns with no errors, all connections // associated with this Client have been closed. func (c *Client) Disconnect(ctx context.Context) error { if ctx == nil { ctx = context.Background() } c.endSessions(ctx) return replaceErrors(c.topology.Disconnect(ctx)) } // Ping verifies that the client can connect to the topology. // If readPreference is nil then will use the client's default read // preference. func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error { if ctx == nil { ctx = context.Background() } if rp == nil { rp = c.readPreference } db := c.Database("admin") res := db.RunCommand(ctx, bson.D{ {"ping", 1}, }, options.RunCmd().SetReadPreference(rp)) return replaceErrors(res.Err()) } // StartSession starts a new session. func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error) { if c.topology.SessionPool == nil { return nil, ErrClientDisconnected } sopts := options.MergeSessionOptions(opts...) coreOpts := &session.ClientOptions{ DefaultReadConcern: c.readConcern, DefaultReadPreference: c.readPreference, DefaultWriteConcern: c.writeConcern, } if sopts.CausalConsistency != nil { coreOpts.CausalConsistency = sopts.CausalConsistency } if sopts.DefaultReadConcern != nil { coreOpts.DefaultReadConcern = sopts.DefaultReadConcern } if sopts.DefaultWriteConcern != nil { coreOpts.DefaultWriteConcern = sopts.DefaultWriteConcern } if sopts.DefaultReadPreference != nil { coreOpts.DefaultReadPreference = sopts.DefaultReadPreference } if sopts.DefaultMaxCommitTime != nil { coreOpts.DefaultMaxCommitTime = sopts.DefaultMaxCommitTime } sess, err := session.NewClientSession(c.topology.SessionPool, c.id, session.Explicit, coreOpts) if err != nil { return nil, replaceErrors(err) } sess.RetryWrite = c.retryWrites sess.RetryRead = c.retryReads return &sessionImpl{ clientSession: sess, client: c, topo: c.topology, }, nil } func (c *Client) endSessions(ctx context.Context) { if c.topology.SessionPool == nil { return } ids := c.topology.SessionPool.IDSlice() idx, idArray := bsoncore.AppendArrayStart(nil) for i, id := range ids { idDoc, _ := id.MarshalBSON() idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), idDoc) } idArray, _ = bsoncore.AppendArrayEnd(idArray, idx) op := operation.NewEndSessions(idArray).ClusterClock(c.clock).Deployment(c.topology). ServerSelector(description.ReadPrefSelector(readpref.PrimaryPreferred())).CommandMonitor(c.monitor).Database("admin") idx, idArray = bsoncore.AppendArrayStart(nil) totalNumIDs := len(ids) for i := 0; i < totalNumIDs; i++ { idDoc, _ := ids[i].MarshalBSON() idArray = bsoncore.AppendDocumentElement(idArray, strconv.Itoa(i), idDoc) if ((i+1)%batchSize) == 0 || i == totalNumIDs-1 { idArray, _ = bsoncore.AppendArrayEnd(idArray, idx) _ = op.SessionIDs(idArray).Execute(ctx) idArray = idArray[:0] idx = 0 } } } func (c *Client) configure(opts *options.ClientOptions) error { if err := opts.Validate(); err != nil { return err } var connOpts []topology.ConnectionOption var serverOpts []topology.ServerOption var topologyOpts []topology.Option // TODO(GODRIVER-814): Add tests for topology, server, and connection related options. // AppName var appName string if opts.AppName != nil { appName = *opts.AppName } // Compressors & ZlibLevel var comps []string if len(opts.Compressors) > 0 { comps = opts.Compressors connOpts = append(connOpts, topology.WithCompressors( func(compressors []string) []string { return append(compressors, comps...) }, )) for _, comp := range comps { if comp == "zlib" { connOpts = append(connOpts, topology.WithZlibLevel(func(level *int) *int { return opts.ZlibLevel })) } } serverOpts = append(serverOpts, topology.WithCompressionOptions( func(opts ...string) []string { return append(opts, comps...) }, )) } // Handshaker var handshaker = func(driver.Handshaker) driver.Handshaker { return operation.NewIsMaster().AppName(appName).Compressors(comps) } // Auth & Database & Password & Username if opts.Auth != nil { cred := &auth.Cred{ Username: opts.Auth.Username, Password: opts.Auth.Password, PasswordSet: opts.Auth.PasswordSet, Props: opts.Auth.AuthMechanismProperties, Source: opts.Auth.AuthSource, } mechanism := opts.Auth.AuthMechanism if len(cred.Source) == 0 { switch strings.ToUpper(mechanism) { case auth.MongoDBX509, auth.GSSAPI, auth.PLAIN: cred.Source = "$external" default: cred.Source = "admin" } } authenticator, err := auth.CreateAuthenticator(mechanism, cred) if err != nil { return err } handshakeOpts := &auth.HandshakeOptions{ AppName: appName, Authenticator: authenticator, Compressors: comps, } if mechanism == "" { // Required for SASL mechanism negotiation during handshake handshakeOpts.DBUser = cred.Source + "." + cred.Username } if opts.AuthenticateToAnything != nil && *opts.AuthenticateToAnything { // Authenticate arbiters handshakeOpts.PerformAuthentication = func(serv description.Server) bool { return true } } handshaker = func(driver.Handshaker) driver.Handshaker { return auth.Handshaker(nil, handshakeOpts) } } connOpts = append(connOpts, topology.WithHandshaker(handshaker)) // ConnectTimeout if opts.ConnectTimeout != nil { serverOpts = append(serverOpts, topology.WithHeartbeatTimeout( func(time.Duration) time.Duration { return *opts.ConnectTimeout }, )) connOpts = append(connOpts, topology.WithConnectTimeout( func(time.Duration) time.Duration { return *opts.ConnectTimeout }, )) } // Dialer if opts.Dialer != nil { connOpts = append(connOpts, topology.WithDialer( func(topology.Dialer) topology.Dialer { return opts.Dialer }, )) } // Direct if opts.Direct != nil && *opts.Direct { topologyOpts = append(topologyOpts, topology.WithMode( func(topology.MonitorMode) topology.MonitorMode { return topology.SingleMode }, )) } // HeartbeatInterval if opts.HeartbeatInterval != nil { serverOpts = append(serverOpts, topology.WithHeartbeatInterval( func(time.Duration) time.Duration { return *opts.HeartbeatInterval }, )) } // Hosts hosts := []string{"localhost:27017"} // default host if len(opts.Hosts) > 0 { hosts = opts.Hosts } topologyOpts = append(topologyOpts, topology.WithSeedList( func(...string) []string { return hosts }, )) // LocalThreshold c.localThreshold = defaultLocalThreshold if opts.LocalThreshold != nil { c.localThreshold = *opts.LocalThreshold } // MaxConIdleTime if opts.MaxConnIdleTime != nil { connOpts = append(connOpts, topology.WithIdleTimeout( func(time.Duration) time.Duration { return *opts.MaxConnIdleTime }, )) } // MaxPoolSize if opts.MaxPoolSize != nil { serverOpts = append( serverOpts, topology.WithMaxConnections(func(uint64) uint64 { return *opts.MaxPoolSize }), ) } // MinPoolSize if opts.MinPoolSize != nil { serverOpts = append( serverOpts, topology.WithMinConnections(func(uint64) uint64 { return *opts.MinPoolSize }), ) } // PoolMonitor if opts.PoolMonitor != nil { serverOpts = append( serverOpts, topology.WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor { return opts.PoolMonitor }), ) } // Monitor if opts.Monitor != nil { c.monitor = opts.Monitor connOpts = append(connOpts, topology.WithMonitor( func(*event.CommandMonitor) *event.CommandMonitor { return opts.Monitor }, )) } // ReadConcern c.readConcern = readconcern.New() if opts.ReadConcern != nil { c.readConcern = opts.ReadConcern } // ReadPreference c.readPreference = readpref.Primary() if opts.ReadPreference != nil { c.readPreference = opts.ReadPreference } // Registry c.registry = bson.DefaultRegistry if opts.Registry != nil { c.registry = opts.Registry } // ReplicaSet if opts.ReplicaSet != nil { topologyOpts = append(topologyOpts, topology.WithReplicaSetName( func(string) string { return *opts.ReplicaSet }, )) } // RetryWrites c.retryWrites = true // retry writes on by default if opts.RetryWrites != nil { c.retryWrites = *opts.RetryWrites } c.retryReads = true if opts.RetryReads != nil { c.retryReads = *opts.RetryReads } // ServerSelectionTimeout if opts.ServerSelectionTimeout != nil { topologyOpts = append(topologyOpts, topology.WithServerSelectionTimeout( func(time.Duration) time.Duration { return *opts.ServerSelectionTimeout }, )) } // SocketTimeout if opts.SocketTimeout != nil { connOpts = append( connOpts, topology.WithReadTimeout(func(time.Duration) time.Duration { return *opts.SocketTimeout }), topology.WithWriteTimeout(func(time.Duration) time.Duration { return *opts.SocketTimeout }), ) } // TLSConfig if opts.TLSConfig != nil { connOpts = append(connOpts, topology.WithTLSConfig( func(*tls.Config) *tls.Config { return opts.TLSConfig }, )) } // WriteConcern if opts.WriteConcern != nil { c.writeConcern = opts.WriteConcern } // ClusterClock c.clock = new(session.ClusterClock) serverOpts = append( serverOpts, topology.WithClock(func(*session.ClusterClock) *session.ClusterClock { return c.clock }), topology.WithConnectionOptions(func(...topology.ConnectionOption) []topology.ConnectionOption { return connOpts }), ) c.topologyOptions = append(topologyOpts, topology.WithServerOptions( func(...topology.ServerOption) []topology.ServerOption { return serverOpts }, )) return nil } // validSession returns an error if the session doesn't belong to the client func (c *Client) validSession(sess *session.Client) error { if sess != nil && !uuid.Equal(sess.ClientID, c.id) { return ErrWrongClient } return nil } // Database returns a handle for a given database. func (c *Client) Database(name string, opts ...*options.DatabaseOptions) *Database { return newDatabase(c, name, opts...) } // ListDatabases returns a ListDatabasesResult. func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ...*options.ListDatabasesOptions) (ListDatabasesResult, error) { if ctx == nil { ctx = context.Background() } sess := sessionFromContext(ctx) err := c.validSession(sess) if sess == nil && c.topology.SessionPool != nil { sess, err = session.NewClientSession(c.topology.SessionPool, c.id, session.Implicit) if err != nil { return ListDatabasesResult{}, err } defer sess.EndSession() } err = c.validSession(sess) if err != nil { return ListDatabasesResult{}, err } filterDoc, err := transformBsoncoreDocument(c.registry, filter) if err != nil { return ListDatabasesResult{}, err } selector := makePinnedSelector(sess, description.CompositeSelector([]description.ServerSelector{ description.ReadPrefSelector(readpref.Primary()), description.LatencySelector(c.localThreshold), })) ldo := options.MergeListDatabasesOptions(opts...) op := operation.NewListDatabases(filterDoc). Session(sess).ReadPreference(c.readPreference).CommandMonitor(c.monitor). ServerSelector(selector).ClusterClock(c.clock).Database("admin").Deployment(c.topology) if ldo.NameOnly != nil { op = op.NameOnly(*ldo.NameOnly) } retry := driver.RetryNone if c.retryReads { retry = driver.RetryOncePerCommand } op.Retry(retry) err = op.Execute(ctx) if err != nil { return ListDatabasesResult{}, replaceErrors(err) } return newListDatabasesResultFromOperation(op.Result()), nil } // ListDatabaseNames returns a slice containing the names of all of the databases on the server. func (c *Client) ListDatabaseNames(ctx context.Context, filter interface{}, opts ...*options.ListDatabasesOptions) ([]string, error) { opts = append(opts, options.ListDatabases().SetNameOnly(true)) res, err := c.ListDatabases(ctx, filter, opts...) if err != nil { return nil, err } names := make([]string, 0) for _, spec := range res.Databases { names = append(names, spec.Name) } return names, nil } // WithSession allows a user to start a session themselves and manage // its lifetime. The only way to provide a session to a CRUD method is // to invoke that CRUD method with the mongo.SessionContext within the // closure. The mongo.SessionContext can be used as a regular context, // so methods like context.WithDeadline and context.WithTimeout are // supported. // // If the context.Context already has a mongo.Session attached, that // mongo.Session will be replaced with the one provided. // // Errors returned from the closure are transparently returned from // this function. func WithSession(ctx context.Context, sess Session, fn func(SessionContext) error) error { return fn(contextWithSession(ctx, sess)) } // UseSession creates a default session, that is only valid for the // lifetime of the closure. No cleanup outside of closing the session // is done upon exiting the closure. This means that an outstanding // transaction will be aborted, even if the closure returns an error. // // If ctx already contains a mongo.Session, that mongo.Session will be // replaced with the newly created mongo.Session. // // Errors returned from the closure are transparently returned from // this method. func (c *Client) UseSession(ctx context.Context, fn func(SessionContext) error) error { return c.UseSessionWithOptions(ctx, options.Session(), fn) } // UseSessionWithOptions works like UseSession but allows the caller // to specify the options used to create the session. func (c *Client) UseSessionWithOptions(ctx context.Context, opts *options.SessionOptions, fn func(SessionContext) error) error { defaultSess, err := c.StartSession(opts) if err != nil { return err } defer defaultSess.EndSession(ctx) sessCtx := sessionContext{ Context: context.WithValue(ctx, sessionKey{}, defaultSess), Session: defaultSess, } return fn(sessCtx) } // Watch returns a change stream cursor used to receive information of changes to the client. This method is preferred // to running a raw aggregation with a $changeStream stage because it supports resumability in the case of some errors. // The client must have read concern majority or no read concern for a change stream to be created successfully. func (c *Client) Watch(ctx context.Context, pipeline interface{}, opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { if c.topology.SessionPool == nil { return nil, ErrClientDisconnected } csConfig := changeStreamConfig{ readConcern: c.readConcern, readPreference: c.readPreference, client: c, registry: c.registry, streamType: ClientStream, } return newChangeStream(ctx, csConfig, pipeline, opts...) }