// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package mongo

import (
	"context"
	"errors"
	"time"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/description"
	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
)

// ErrWrongClient is returned when a user attempts to pass in a session created by a different client than
// the method call is using.
var ErrWrongClient = errors.New("session was not created by this client")

// withTransactionTimeout is the total time budget for WithTransaction's retry loop. It arms a timer
// that is checked after each failed callback or commit attempt. Once it has fired, the latest error
// is returned instead of being retried. It is declared as a var rather than a const, presumably so
// tests can shorten it (NOTE(review): confirm).
var withTransactionTimeout = 120 * time.Second

// SessionContext is a hybrid interface. It combines a context.Context with
// a mongo.Session. This type can be used as a regular context.Context or
// Session type. It is not goroutine safe and should not be used in multiple goroutines concurrently.
type SessionContext interface {
	context.Context
	Session
}

// sessionContext is the concrete SessionContext built by contextWithSession: it embeds both the
// (session-carrying) context and the Session itself, so it satisfies both interfaces.
type sessionContext struct {
	context.Context
	Session
}

// sessionKey is the unexported context key under which contextWithSession stores a Session and from
// which sessionFromContext reads it. Using a private struct type prevents collisions with keys
// defined in other packages.
type sessionKey struct {
}

// Session is the interface that represents a sequential set of operations executed.
// Instances of this interface can be used to use transactions against the server
// and to enable causally consistent behavior for applications.
type Session interface { EndSession(context.Context) WithTransaction(ctx context.Context, fn func(sessCtx SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) StartTransaction(...*options.TransactionOptions) error AbortTransaction(context.Context) error CommitTransaction(context.Context) error ClusterTime() bson.Raw AdvanceClusterTime(bson.Raw) error OperationTime() *primitive.Timestamp AdvanceOperationTime(*primitive.Timestamp) error Client() *Client session() } // sessionImpl represents a set of sequential operations executed by an application that are related in some way. type sessionImpl struct { clientSession *session.Client client *Client topo *topology.Topology didCommitAfterStart bool // true if commit was called after start with no other operations } // EndSession ends the session. func (s *sessionImpl) EndSession(ctx context.Context) { if s.clientSession.TransactionInProgress() { // ignore all errors aborting during an end session _ = s.AbortTransaction(ctx) } s.clientSession.EndSession() } // WithTransaction creates a transaction on this session and runs the given callback, retrying for // TransientTransactionError and UnknownTransactionCommitResult errors. The only way to provide a // session to a CRUD method is to invoke that CRUD method with the mongo.SessionContext within the // callback. The mongo.SessionContext can be used as a regular context, so methods like // context.WithDeadline and context.WithTimeout are supported. // // If the context.Context already has a mongo.Session attached, that mongo.Session will be replaced // with the one provided. // // The callback may be run multiple times due to retry attempts. Non-retryable and timed out errors // are returned from this function. 
func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error) { timeout := time.NewTimer(withTransactionTimeout) defer timeout.Stop() var err error for { err = s.StartTransaction(opts...) if err != nil { return nil, err } res, err := fn(contextWithSession(ctx, s)) if err != nil { if s.clientSession.TransactionRunning() { _ = s.AbortTransaction(ctx) } select { case <-timeout.C: return nil, err default: } if cerr, ok := err.(CommandError); ok { if cerr.HasErrorLabel(driver.TransientTransactionError) { continue } } return res, err } err = s.clientSession.CheckAbortTransaction() if err != nil { return res, nil } CommitLoop: for { err = s.CommitTransaction(ctx) if err == nil { return res, nil } select { case <-timeout.C: return res, err default: } if cerr, ok := err.(CommandError); ok { if cerr.HasErrorLabel(driver.UnknownTransactionCommitResult) && !cerr.IsMaxTimeMSExpiredError() { continue } if cerr.HasErrorLabel(driver.TransientTransactionError) { break CommitLoop } } return res, err } } } // StartTransaction starts a transaction for this session. func (s *sessionImpl) StartTransaction(opts ...*options.TransactionOptions) error { err := s.clientSession.CheckStartTransaction() if err != nil { return err } s.didCommitAfterStart = false topts := options.MergeTransactionOptions(opts...) 
coreOpts := &session.TransactionOptions{ ReadConcern: topts.ReadConcern, ReadPreference: topts.ReadPreference, WriteConcern: topts.WriteConcern, MaxCommitTime: topts.MaxCommitTime, } return s.clientSession.StartTransaction(coreOpts) } // AbortTransaction aborts the session's transaction, returning any errors and error codes func (s *sessionImpl) AbortTransaction(ctx context.Context) error { err := s.clientSession.CheckAbortTransaction() if err != nil { return err } // Do not run the abort command if the transaction is in starting state if s.clientSession.TransactionStarting() || s.didCommitAfterStart { return s.clientSession.AbortTransaction() } selector := makePinnedSelector(s.clientSession, description.WriteSelector()) s.clientSession.Aborting = true err = operation.NewAbortTransaction().Session(s.clientSession).ClusterClock(s.client.clock).Database("admin"). Deployment(s.topo).WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector). Retry(driver.RetryOncePerCommand).CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).Execute(ctx) s.clientSession.Aborting = false _ = s.clientSession.AbortTransaction() return replaceErrors(err) } // CommitTransaction commits the sesson's transaction. func (s *sessionImpl) CommitTransaction(ctx context.Context) error { err := s.clientSession.CheckCommitTransaction() if err != nil { return err } // Do not run the commit command if the transaction is in started state if s.clientSession.TransactionStarting() || s.didCommitAfterStart { s.didCommitAfterStart = true return s.clientSession.CommitTransaction() } if s.clientSession.TransactionCommitted() { s.clientSession.RetryingCommit = true } selector := makePinnedSelector(s.clientSession, description.WriteSelector()) s.clientSession.Committing = true op := operation.NewCommitTransaction(). Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").Deployment(s.topo). 
WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).Retry(driver.RetryOncePerCommand). CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)) if s.clientSession.CurrentMct != nil { op.MaxTimeMS(int64(*s.clientSession.CurrentMct / time.Millisecond)) } err = op.Execute(ctx) s.clientSession.Committing = false commitErr := s.clientSession.CommitTransaction() // We set the write concern to majority for subsequent calls to CommitTransaction. s.clientSession.UpdateCommitTransactionWriteConcern() if err != nil { return replaceErrors(err) } return commitErr } func (s *sessionImpl) ClusterTime() bson.Raw { return s.clientSession.ClusterTime } func (s *sessionImpl) AdvanceClusterTime(d bson.Raw) error { return s.clientSession.AdvanceClusterTime(d) } func (s *sessionImpl) OperationTime() *primitive.Timestamp { return s.clientSession.OperationTime } func (s *sessionImpl) AdvanceOperationTime(ts *primitive.Timestamp) error { return s.clientSession.AdvanceOperationTime(ts) } func (s *sessionImpl) Client() *Client { return s.client } func (*sessionImpl) session() { } // sessionFromContext checks for a sessionImpl in the argued context and returns the session if it // exists func sessionFromContext(ctx context.Context) *session.Client { s := ctx.Value(sessionKey{}) if ses, ok := s.(*sessionImpl); ses != nil && ok { return ses.clientSession } return nil } func contextWithSession(ctx context.Context, sess Session) SessionContext { return &sessionContext{ Context: context.WithValue(ctx, sessionKey{}, sess), Session: sess, } }