* Update go-ldap dependency * Request for public keys only if attribute is setmaster
@@ -1006,12 +1006,12 @@ | |||||
version = "v1.31.1" | version = "v1.31.1" | ||||
[[projects]] | [[projects]] | ||||
digest = "1:01f4ac37c52bda6f7e1bd73680a99f88733c0408aaa159ecb1ba53a1ade9423c" | |||||
digest = "1:7e1c00b9959544fa1ccca7cf0407a5b29ac6d5201059c4fac6f599cb99bfd24d" | |||||
name = "gopkg.in/ldap.v2" | name = "gopkg.in/ldap.v2" | ||||
packages = ["."] | packages = ["."] | ||||
pruneopts = "NUT" | pruneopts = "NUT" | ||||
revision = "d0a5ced67b4dc310b9158d63a2c6f9c5ec13f105" | |||||
version = "v2.4.1" | |||||
revision = "bb7a9ca6e4fbc2129e3db588a34bc970ffe811a9" | |||||
version = "v2.5.1" | |||||
[[projects]] | [[projects]] | ||||
digest = "1:cfe1730a152ff033ad7d9c115d22e36b19eec6d5928c06146b9119be45d39dc0" | digest = "1:cfe1730a152ff033ad7d9c115d22e36b19eec6d5928c06146b9119be45d39dc0" | ||||
@@ -1174,6 +1174,7 @@ | |||||
"github.com/keybase/go-crypto/openpgp", | "github.com/keybase/go-crypto/openpgp", | ||||
"github.com/keybase/go-crypto/openpgp/armor", | "github.com/keybase/go-crypto/openpgp/armor", | ||||
"github.com/keybase/go-crypto/openpgp/packet", | "github.com/keybase/go-crypto/openpgp/packet", | ||||
"github.com/klauspost/compress/gzip", | |||||
"github.com/lafriks/xormstore", | "github.com/lafriks/xormstore", | ||||
"github.com/lib/pq", | "github.com/lib/pq", | ||||
"github.com/lunny/dingtalk_webhook", | "github.com/lunny/dingtalk_webhook", | ||||
@@ -247,11 +247,17 @@ func (ls *Source) SearchEntry(name, passwd string, directBind bool) *SearchResul | |||||
return nil | return nil | ||||
} | } | ||||
var isAttributeSSHPublicKeySet = len(strings.TrimSpace(ls.AttributeSSHPublicKey)) > 0 | |||||
attribs := []string{ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail} | |||||
if isAttributeSSHPublicKeySet { | |||||
attribs = append(attribs, ls.AttributeSSHPublicKey) | |||||
} | |||||
log.Trace("Fetching attributes '%v', '%v', '%v', '%v', '%v' with filter %s and base %s", ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail, ls.AttributeSSHPublicKey, userFilter, userDN) | log.Trace("Fetching attributes '%v', '%v', '%v', '%v', '%v' with filter %s and base %s", ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail, ls.AttributeSSHPublicKey, userFilter, userDN) | ||||
search := ldap.NewSearchRequest( | search := ldap.NewSearchRequest( | ||||
userDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, userFilter, | userDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, userFilter, | ||||
[]string{ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail, ls.AttributeSSHPublicKey}, | |||||
nil) | |||||
attribs, nil) | |||||
sr, err := l.Search(search) | sr, err := l.Search(search) | ||||
if err != nil { | if err != nil { | ||||
@@ -267,11 +273,15 @@ func (ls *Source) SearchEntry(name, passwd string, directBind bool) *SearchResul | |||||
return nil | return nil | ||||
} | } | ||||
var sshPublicKey []string | |||||
username := sr.Entries[0].GetAttributeValue(ls.AttributeUsername) | username := sr.Entries[0].GetAttributeValue(ls.AttributeUsername) | ||||
firstname := sr.Entries[0].GetAttributeValue(ls.AttributeName) | firstname := sr.Entries[0].GetAttributeValue(ls.AttributeName) | ||||
surname := sr.Entries[0].GetAttributeValue(ls.AttributeSurname) | surname := sr.Entries[0].GetAttributeValue(ls.AttributeSurname) | ||||
mail := sr.Entries[0].GetAttributeValue(ls.AttributeMail) | mail := sr.Entries[0].GetAttributeValue(ls.AttributeMail) | ||||
sshPublicKey := sr.Entries[0].GetAttributeValues(ls.AttributeSSHPublicKey) | |||||
if isAttributeSSHPublicKeySet { | |||||
sshPublicKey = sr.Entries[0].GetAttributeValues(ls.AttributeSSHPublicKey) | |||||
} | |||||
isAdmin := checkAdmin(l, ls, userDN) | isAdmin := checkAdmin(l, ls, userDN) | ||||
if !directBind && ls.AttributesInBind { | if !directBind && ls.AttributesInBind { | ||||
@@ -320,11 +330,17 @@ func (ls *Source) SearchEntries() []*SearchResult { | |||||
userFilter := fmt.Sprintf(ls.Filter, "*") | userFilter := fmt.Sprintf(ls.Filter, "*") | ||||
var isAttributeSSHPublicKeySet = len(strings.TrimSpace(ls.AttributeSSHPublicKey)) > 0 | |||||
attribs := []string{ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail} | |||||
if isAttributeSSHPublicKeySet { | |||||
attribs = append(attribs, ls.AttributeSSHPublicKey) | |||||
} | |||||
log.Trace("Fetching attributes '%v', '%v', '%v', '%v', '%v' with filter %s and base %s", ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail, ls.AttributeSSHPublicKey, userFilter, ls.UserBase) | log.Trace("Fetching attributes '%v', '%v', '%v', '%v', '%v' with filter %s and base %s", ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail, ls.AttributeSSHPublicKey, userFilter, ls.UserBase) | ||||
search := ldap.NewSearchRequest( | search := ldap.NewSearchRequest( | ||||
ls.UserBase, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, userFilter, | ls.UserBase, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, userFilter, | ||||
[]string{ls.AttributeUsername, ls.AttributeName, ls.AttributeSurname, ls.AttributeMail, ls.AttributeSSHPublicKey}, | |||||
nil) | |||||
attribs, nil) | |||||
var sr *ldap.SearchResult | var sr *ldap.SearchResult | ||||
if ls.UsePagedSearch() { | if ls.UsePagedSearch() { | ||||
@@ -341,12 +357,14 @@ func (ls *Source) SearchEntries() []*SearchResult { | |||||
for i, v := range sr.Entries { | for i, v := range sr.Entries { | ||||
result[i] = &SearchResult{ | result[i] = &SearchResult{ | ||||
Username: v.GetAttributeValue(ls.AttributeUsername), | |||||
Name: v.GetAttributeValue(ls.AttributeName), | |||||
Surname: v.GetAttributeValue(ls.AttributeSurname), | |||||
Mail: v.GetAttributeValue(ls.AttributeMail), | |||||
SSHPublicKey: v.GetAttributeValues(ls.AttributeSSHPublicKey), | |||||
IsAdmin: checkAdmin(l, ls, v.DN), | |||||
Username: v.GetAttributeValue(ls.AttributeUsername), | |||||
Name: v.GetAttributeValue(ls.AttributeName), | |||||
Surname: v.GetAttributeValue(ls.AttributeSurname), | |||||
Mail: v.GetAttributeValue(ls.AttributeMail), | |||||
IsAdmin: checkAdmin(l, ls, v.DN), | |||||
} | |||||
if isAttributeSSHPublicKeySet { | |||||
result[i].SSHPublicKey = v.GetAttributeValues(ls.AttributeSSHPublicKey) | |||||
} | } | ||||
} | } | ||||
@@ -1,27 +1,22 @@ | |||||
Copyright (c) 2012 The Go Authors. All rights reserved. | |||||
The MIT License (MIT) | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
Copyright (c) 2011-2015 Michael Mitton (mmitton@gmail.com) | |||||
Portions copyright (c) 2015-2016 go-ldap Authors | |||||
* Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
* Redistributions in binary form must reproduce the above | |||||
copyright notice, this list of conditions and the following disclaimer | |||||
in the documentation and/or other materials provided with the | |||||
distribution. | |||||
* Neither the name of Google Inc. nor the names of its | |||||
contributors may be used to endorse or promote products derived from | |||||
this software without specific prior written permission. | |||||
Permission is hereby granted, free of charge, to any person obtaining a copy | |||||
of this software and associated documentation files (the "Software"), to deal | |||||
in the Software without restriction, including without limitation the rights | |||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |||||
copies of the Software, and to permit persons to whom the Software is | |||||
furnished to do so, subject to the following conditions: | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | |||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | |||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | |||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | |||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | |||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |||||
The above copyright notice and this permission notice shall be included in all | |||||
copies or substantial portions of the Software. | |||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |||||
SOFTWARE. |
@@ -0,0 +1,13 @@ | |||||
// +build go1.4 | |||||
package ldap | |||||
import ( | |||||
"sync/atomic" | |||||
) | |||||
// For compilers that support it, we just use the underlying sync/atomic.Value | |||||
// type. | |||||
type atomicValue struct { | |||||
atomic.Value | |||||
} |
@@ -0,0 +1,28 @@ | |||||
// +build !go1.4 | |||||
package ldap | |||||
import ( | |||||
"sync" | |||||
) | |||||
// This is a helper type that emulates the use of the "sync/atomic.Value" | |||||
// struct that's available in Go 1.4 and up. | |||||
type atomicValue struct { | |||||
value interface{} | |||||
lock sync.RWMutex | |||||
} | |||||
func (av *atomicValue) Store(val interface{}) { | |||||
av.lock.Lock() | |||||
av.value = val | |||||
av.lock.Unlock() | |||||
} | |||||
func (av *atomicValue) Load() interface{} { | |||||
av.lock.RLock() | |||||
ret := av.value | |||||
av.lock.RUnlock() | |||||
return ret | |||||
} |
@@ -11,6 +11,7 @@ import ( | |||||
"log" | "log" | ||||
"net" | "net" | ||||
"sync" | "sync" | ||||
"sync/atomic" | |||||
"time" | "time" | ||||
"gopkg.in/asn1-ber.v1" | "gopkg.in/asn1-ber.v1" | ||||
@@ -82,20 +83,18 @@ const ( | |||||
type Conn struct { | type Conn struct { | ||||
conn net.Conn | conn net.Conn | ||||
isTLS bool | isTLS bool | ||||
isClosing bool | |||||
closeErr error | |||||
closing uint32 | |||||
closeErr atomicValue | |||||
isStartingTLS bool | isStartingTLS bool | ||||
Debug debugging | Debug debugging | ||||
chanConfirm chan bool | |||||
chanConfirm chan struct{} | |||||
messageContexts map[int64]*messageContext | messageContexts map[int64]*messageContext | ||||
chanMessage chan *messagePacket | chanMessage chan *messagePacket | ||||
chanMessageID chan int64 | chanMessageID chan int64 | ||||
wgSender sync.WaitGroup | |||||
wgClose sync.WaitGroup | wgClose sync.WaitGroup | ||||
once sync.Once | |||||
outstandingRequests uint | outstandingRequests uint | ||||
messageMutex sync.Mutex | messageMutex sync.Mutex | ||||
requestTimeout time.Duration | |||||
requestTimeout int64 | |||||
} | } | ||||
var _ Client = &Conn{} | var _ Client = &Conn{} | ||||
@@ -142,7 +141,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { | |||||
func NewConn(conn net.Conn, isTLS bool) *Conn { | func NewConn(conn net.Conn, isTLS bool) *Conn { | ||||
return &Conn{ | return &Conn{ | ||||
conn: conn, | conn: conn, | ||||
chanConfirm: make(chan bool), | |||||
chanConfirm: make(chan struct{}), | |||||
chanMessageID: make(chan int64), | chanMessageID: make(chan int64), | ||||
chanMessage: make(chan *messagePacket, 10), | chanMessage: make(chan *messagePacket, 10), | ||||
messageContexts: map[int64]*messageContext{}, | messageContexts: map[int64]*messageContext{}, | ||||
@@ -158,12 +157,22 @@ func (l *Conn) Start() { | |||||
l.wgClose.Add(1) | l.wgClose.Add(1) | ||||
} | } | ||||
// isClosing returns whether or not we're currently closing. | |||||
func (l *Conn) isClosing() bool { | |||||
return atomic.LoadUint32(&l.closing) == 1 | |||||
} | |||||
// setClosing sets the closing value to true | |||||
func (l *Conn) setClosing() bool { | |||||
return atomic.CompareAndSwapUint32(&l.closing, 0, 1) | |||||
} | |||||
// Close closes the connection. | // Close closes the connection. | ||||
func (l *Conn) Close() { | func (l *Conn) Close() { | ||||
l.once.Do(func() { | |||||
l.isClosing = true | |||||
l.wgSender.Wait() | |||||
l.messageMutex.Lock() | |||||
defer l.messageMutex.Unlock() | |||||
if l.setClosing() { | |||||
l.Debug.Printf("Sending quit message and waiting for confirmation") | l.Debug.Printf("Sending quit message and waiting for confirmation") | ||||
l.chanMessage <- &messagePacket{Op: MessageQuit} | l.chanMessage <- &messagePacket{Op: MessageQuit} | ||||
<-l.chanConfirm | <-l.chanConfirm | ||||
@@ -171,27 +180,25 @@ func (l *Conn) Close() { | |||||
l.Debug.Printf("Closing network connection") | l.Debug.Printf("Closing network connection") | ||||
if err := l.conn.Close(); err != nil { | if err := l.conn.Close(); err != nil { | ||||
log.Print(err) | |||||
log.Println(err) | |||||
} | } | ||||
l.wgClose.Done() | l.wgClose.Done() | ||||
}) | |||||
} | |||||
l.wgClose.Wait() | l.wgClose.Wait() | ||||
} | } | ||||
// SetTimeout sets the time after a request is sent that a MessageTimeout triggers | // SetTimeout sets the time after a request is sent that a MessageTimeout triggers | ||||
func (l *Conn) SetTimeout(timeout time.Duration) { | func (l *Conn) SetTimeout(timeout time.Duration) { | ||||
if timeout > 0 { | if timeout > 0 { | ||||
l.requestTimeout = timeout | |||||
atomic.StoreInt64(&l.requestTimeout, int64(timeout)) | |||||
} | } | ||||
} | } | ||||
// Returns the next available messageID | // Returns the next available messageID | ||||
func (l *Conn) nextMessageID() int64 { | func (l *Conn) nextMessageID() int64 { | ||||
if l.chanMessageID != nil { | |||||
if messageID, ok := <-l.chanMessageID; ok { | |||||
return messageID | |||||
} | |||||
if messageID, ok := <-l.chanMessageID; ok { | |||||
return messageID | |||||
} | } | ||||
return 0 | return 0 | ||||
} | } | ||||
@@ -258,7 +265,7 @@ func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) { | |||||
} | } | ||||
func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) { | func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) { | ||||
if l.isClosing { | |||||
if l.isClosing() { | |||||
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) | return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) | ||||
} | } | ||||
l.messageMutex.Lock() | l.messageMutex.Lock() | ||||
@@ -297,7 +304,7 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) | |||||
func (l *Conn) finishMessage(msgCtx *messageContext) { | func (l *Conn) finishMessage(msgCtx *messageContext) { | ||||
close(msgCtx.done) | close(msgCtx.done) | ||||
if l.isClosing { | |||||
if l.isClosing() { | |||||
return | return | ||||
} | } | ||||
@@ -316,12 +323,12 @@ func (l *Conn) finishMessage(msgCtx *messageContext) { | |||||
} | } | ||||
func (l *Conn) sendProcessMessage(message *messagePacket) bool { | func (l *Conn) sendProcessMessage(message *messagePacket) bool { | ||||
if l.isClosing { | |||||
l.messageMutex.Lock() | |||||
defer l.messageMutex.Unlock() | |||||
if l.isClosing() { | |||||
return false | return false | ||||
} | } | ||||
l.wgSender.Add(1) | |||||
l.chanMessage <- message | l.chanMessage <- message | ||||
l.wgSender.Done() | |||||
return true | return true | ||||
} | } | ||||
@@ -333,15 +340,14 @@ func (l *Conn) processMessages() { | |||||
for messageID, msgCtx := range l.messageContexts { | for messageID, msgCtx := range l.messageContexts { | ||||
// If we are closing due to an error, inform anyone who | // If we are closing due to an error, inform anyone who | ||||
// is waiting about the error. | // is waiting about the error. | ||||
if l.isClosing && l.closeErr != nil { | |||||
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr}) | |||||
if l.isClosing() && l.closeErr.Load() != nil { | |||||
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}) | |||||
} | } | ||||
l.Debug.Printf("Closing channel for MessageID %d", messageID) | l.Debug.Printf("Closing channel for MessageID %d", messageID) | ||||
close(msgCtx.responses) | close(msgCtx.responses) | ||||
delete(l.messageContexts, messageID) | delete(l.messageContexts, messageID) | ||||
} | } | ||||
close(l.chanMessageID) | close(l.chanMessageID) | ||||
l.chanConfirm <- true | |||||
close(l.chanConfirm) | close(l.chanConfirm) | ||||
}() | }() | ||||
@@ -350,11 +356,7 @@ func (l *Conn) processMessages() { | |||||
select { | select { | ||||
case l.chanMessageID <- messageID: | case l.chanMessageID <- messageID: | ||||
messageID++ | messageID++ | ||||
case message, ok := <-l.chanMessage: | |||||
if !ok { | |||||
l.Debug.Printf("Shutting down - message channel is closed") | |||||
return | |||||
} | |||||
case message := <-l.chanMessage: | |||||
switch message.Op { | switch message.Op { | ||||
case MessageQuit: | case MessageQuit: | ||||
l.Debug.Printf("Shutting down - quit message received") | l.Debug.Printf("Shutting down - quit message received") | ||||
@@ -377,14 +379,15 @@ func (l *Conn) processMessages() { | |||||
l.messageContexts[message.MessageID] = message.Context | l.messageContexts[message.MessageID] = message.Context | ||||
// Add timeout if defined | // Add timeout if defined | ||||
if l.requestTimeout > 0 { | |||||
requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout)) | |||||
if requestTimeout > 0 { | |||||
go func() { | go func() { | ||||
defer func() { | defer func() { | ||||
if err := recover(); err != nil { | if err := recover(); err != nil { | ||||
log.Printf("ldap: recovered panic in RequestTimeout: %v", err) | log.Printf("ldap: recovered panic in RequestTimeout: %v", err) | ||||
} | } | ||||
}() | }() | ||||
time.Sleep(l.requestTimeout) | |||||
time.Sleep(requestTimeout) | |||||
timeoutMessage := &messagePacket{ | timeoutMessage := &messagePacket{ | ||||
Op: MessageTimeout, | Op: MessageTimeout, | ||||
MessageID: message.MessageID, | MessageID: message.MessageID, | ||||
@@ -397,7 +400,7 @@ func (l *Conn) processMessages() { | |||||
if msgCtx, ok := l.messageContexts[message.MessageID]; ok { | if msgCtx, ok := l.messageContexts[message.MessageID]; ok { | ||||
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}) | msgCtx.sendResponse(&PacketResponse{message.Packet, nil}) | ||||
} else { | } else { | ||||
log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing) | |||||
log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing()) | |||||
ber.PrintPacket(message.Packet) | ber.PrintPacket(message.Packet) | ||||
} | } | ||||
case MessageTimeout: | case MessageTimeout: | ||||
@@ -439,8 +442,8 @@ func (l *Conn) reader() { | |||||
packet, err := ber.ReadPacket(l.conn) | packet, err := ber.ReadPacket(l.conn) | ||||
if err != nil { | if err != nil { | ||||
// A read error is expected here if we are closing the connection... | // A read error is expected here if we are closing the connection... | ||||
if !l.isClosing { | |||||
l.closeErr = fmt.Errorf("unable to read LDAP response packet: %s", err) | |||||
if !l.isClosing() { | |||||
l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err)) | |||||
l.Debug.Printf("reader error: %s", err.Error()) | l.Debug.Printf("reader error: %s", err.Error()) | ||||
} | } | ||||
return | return | ||||
@@ -334,18 +334,18 @@ func DecodeControl(packet *ber.Packet) Control { | |||||
for _, child := range sequence.Children { | for _, child := range sequence.Children { | ||||
if child.Tag == 0 { | if child.Tag == 0 { | ||||
//Warning | //Warning | ||||
child := child.Children[0] | |||||
packet := ber.DecodePacket(child.Data.Bytes()) | |||||
warningPacket := child.Children[0] | |||||
packet := ber.DecodePacket(warningPacket.Data.Bytes()) | |||||
val, ok := packet.Value.(int64) | val, ok := packet.Value.(int64) | ||||
if ok { | if ok { | ||||
if child.Tag == 0 { | |||||
if warningPacket.Tag == 0 { | |||||
//timeBeforeExpiration | //timeBeforeExpiration | ||||
c.Expire = val | c.Expire = val | ||||
child.Value = c.Expire | |||||
} else if child.Tag == 1 { | |||||
warningPacket.Value = c.Expire | |||||
} else if warningPacket.Tag == 1 { | |||||
//graceAuthNsRemaining | //graceAuthNsRemaining | ||||
c.Grace = val | c.Grace = val | ||||
child.Value = c.Grace | |||||
warningPacket.Value = c.Grace | |||||
} | } | ||||
} | } | ||||
} else if child.Tag == 1 { | } else if child.Tag == 1 { | ||||
@@ -6,7 +6,7 @@ import ( | |||||
"gopkg.in/asn1-ber.v1" | "gopkg.in/asn1-ber.v1" | ||||
) | ) | ||||
// debbuging type | |||||
// debugging type | |||||
// - has a Printf method to write the debug output | // - has a Printf method to write the debug output | ||||
type debugging bool | type debugging bool | ||||
@@ -2,7 +2,7 @@ | |||||
// Use of this source code is governed by a BSD-style | // Use of this source code is governed by a BSD-style | ||||
// license that can be found in the LICENSE file. | // license that can be found in the LICENSE file. | ||||
// | // | ||||
// File contains DN parsing functionallity | |||||
// File contains DN parsing functionality | |||||
// | // | ||||
// https://tools.ietf.org/html/rfc4514 | // https://tools.ietf.org/html/rfc4514 | ||||
// | // | ||||
@@ -52,7 +52,7 @@ import ( | |||||
"fmt" | "fmt" | ||||
"strings" | "strings" | ||||
ber "gopkg.in/asn1-ber.v1" | |||||
"gopkg.in/asn1-ber.v1" | |||||
) | ) | ||||
// AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514 | // AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514 | ||||
@@ -83,9 +83,19 @@ func ParseDN(str string) (*DN, error) { | |||||
attribute := new(AttributeTypeAndValue) | attribute := new(AttributeTypeAndValue) | ||||
escaping := false | escaping := false | ||||
unescapedTrailingSpaces := 0 | |||||
stringFromBuffer := func() string { | |||||
s := buffer.String() | |||||
s = s[0 : len(s)-unescapedTrailingSpaces] | |||||
buffer.Reset() | |||||
unescapedTrailingSpaces = 0 | |||||
return s | |||||
} | |||||
for i := 0; i < len(str); i++ { | for i := 0; i < len(str); i++ { | ||||
char := str[i] | char := str[i] | ||||
if escaping { | if escaping { | ||||
unescapedTrailingSpaces = 0 | |||||
escaping = false | escaping = false | ||||
switch char { | switch char { | ||||
case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\': | case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\': | ||||
@@ -107,10 +117,10 @@ func ParseDN(str string) (*DN, error) { | |||||
buffer.WriteByte(dst[0]) | buffer.WriteByte(dst[0]) | ||||
i++ | i++ | ||||
} else if char == '\\' { | } else if char == '\\' { | ||||
unescapedTrailingSpaces = 0 | |||||
escaping = true | escaping = true | ||||
} else if char == '=' { | } else if char == '=' { | ||||
attribute.Type = buffer.String() | |||||
buffer.Reset() | |||||
attribute.Type = stringFromBuffer() | |||||
// Special case: If the first character in the value is # the | // Special case: If the first character in the value is # the | ||||
// following data is BER encoded so we can just fast forward | // following data is BER encoded so we can just fast forward | ||||
// and decode. | // and decode. | ||||
@@ -133,7 +143,10 @@ func ParseDN(str string) (*DN, error) { | |||||
} | } | ||||
} else if char == ',' || char == '+' { | } else if char == ',' || char == '+' { | ||||
// We're done with this RDN or value, push it | // We're done with this RDN or value, push it | ||||
attribute.Value = buffer.String() | |||||
if len(attribute.Type) == 0 { | |||||
return nil, errors.New("incomplete type, value pair") | |||||
} | |||||
attribute.Value = stringFromBuffer() | |||||
rdn.Attributes = append(rdn.Attributes, attribute) | rdn.Attributes = append(rdn.Attributes, attribute) | ||||
attribute = new(AttributeTypeAndValue) | attribute = new(AttributeTypeAndValue) | ||||
if char == ',' { | if char == ',' { | ||||
@@ -141,8 +154,17 @@ func ParseDN(str string) (*DN, error) { | |||||
rdn = new(RelativeDN) | rdn = new(RelativeDN) | ||||
rdn.Attributes = make([]*AttributeTypeAndValue, 0) | rdn.Attributes = make([]*AttributeTypeAndValue, 0) | ||||
} | } | ||||
buffer.Reset() | |||||
} else if char == ' ' && buffer.Len() == 0 { | |||||
// ignore unescaped leading spaces | |||||
continue | |||||
} else { | } else { | ||||
if char == ' ' { | |||||
// Track unescaped spaces in case they are trailing and we need to remove them | |||||
unescapedTrailingSpaces++ | |||||
} else { | |||||
// Reset if we see a non-space char | |||||
unescapedTrailingSpaces = 0 | |||||
} | |||||
buffer.WriteByte(char) | buffer.WriteByte(char) | ||||
} | } | ||||
} | } | ||||
@@ -150,9 +172,76 @@ func ParseDN(str string) (*DN, error) { | |||||
if len(attribute.Type) == 0 { | if len(attribute.Type) == 0 { | ||||
return nil, errors.New("DN ended with incomplete type, value pair") | return nil, errors.New("DN ended with incomplete type, value pair") | ||||
} | } | ||||
attribute.Value = buffer.String() | |||||
attribute.Value = stringFromBuffer() | |||||
rdn.Attributes = append(rdn.Attributes, attribute) | rdn.Attributes = append(rdn.Attributes, attribute) | ||||
dn.RDNs = append(dn.RDNs, rdn) | dn.RDNs = append(dn.RDNs, rdn) | ||||
} | } | ||||
return dn, nil | return dn, nil | ||||
} | } | ||||
// Equal returns true if the DNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch). | |||||
// Returns true if they have the same number of relative distinguished names | |||||
// and corresponding relative distinguished names (by position) are the same. | |||||
func (d *DN) Equal(other *DN) bool { | |||||
if len(d.RDNs) != len(other.RDNs) { | |||||
return false | |||||
} | |||||
for i := range d.RDNs { | |||||
if !d.RDNs[i].Equal(other.RDNs[i]) { | |||||
return false | |||||
} | |||||
} | |||||
return true | |||||
} | |||||
// AncestorOf returns true if the other DN consists of at least one RDN followed by all the RDNs of the current DN. | |||||
// "ou=widgets,o=acme.com" is an ancestor of "ou=sprockets,ou=widgets,o=acme.com" | |||||
// "ou=widgets,o=acme.com" is not an ancestor of "ou=sprockets,ou=widgets,o=foo.com" | |||||
// "ou=widgets,o=acme.com" is not an ancestor of "ou=widgets,o=acme.com" | |||||
func (d *DN) AncestorOf(other *DN) bool { | |||||
if len(d.RDNs) >= len(other.RDNs) { | |||||
return false | |||||
} | |||||
// Take the last `len(d.RDNs)` RDNs from the other DN to compare against | |||||
otherRDNs := other.RDNs[len(other.RDNs)-len(d.RDNs):] | |||||
for i := range d.RDNs { | |||||
if !d.RDNs[i].Equal(otherRDNs[i]) { | |||||
return false | |||||
} | |||||
} | |||||
return true | |||||
} | |||||
// Equal returns true if the RelativeDNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch). | |||||
// Relative distinguished names are the same if and only if they have the same number of AttributeTypeAndValues | |||||
// and each attribute of the first RDN is the same as the attribute of the second RDN with the same attribute type. | |||||
// The order of attributes is not significant. | |||||
// Case of attribute types is not significant. | |||||
func (r *RelativeDN) Equal(other *RelativeDN) bool { | |||||
if len(r.Attributes) != len(other.Attributes) { | |||||
return false | |||||
} | |||||
return r.hasAllAttributes(other.Attributes) && other.hasAllAttributes(r.Attributes) | |||||
} | |||||
func (r *RelativeDN) hasAllAttributes(attrs []*AttributeTypeAndValue) bool { | |||||
for _, attr := range attrs { | |||||
found := false | |||||
for _, myattr := range r.Attributes { | |||||
if myattr.Equal(attr) { | |||||
found = true | |||||
break | |||||
} | |||||
} | |||||
if !found { | |||||
return false | |||||
} | |||||
} | |||||
return true | |||||
} | |||||
// Equal returns true if the AttributeTypeAndValue is equivalent to the specified AttributeTypeAndValue | |||||
// Case of the attribute type is not significant | |||||
func (a *AttributeTypeAndValue) Equal(other *AttributeTypeAndValue) bool { | |||||
return strings.EqualFold(a.Type, other.Type) && a.Value == other.Value | |||||
} |
@@ -97,6 +97,13 @@ var LDAPResultCodeMap = map[uint8]string{ | |||||
LDAPResultObjectClassModsProhibited: "Object Class Mods Prohibited", | LDAPResultObjectClassModsProhibited: "Object Class Mods Prohibited", | ||||
LDAPResultAffectsMultipleDSAs: "Affects Multiple DSAs", | LDAPResultAffectsMultipleDSAs: "Affects Multiple DSAs", | ||||
LDAPResultOther: "Other", | LDAPResultOther: "Other", | ||||
ErrorNetwork: "Network Error", | |||||
ErrorFilterCompile: "Filter Compile Error", | |||||
ErrorFilterDecompile: "Filter Decompile Error", | |||||
ErrorDebugging: "Debugging Error", | |||||
ErrorUnexpectedMessage: "Unexpected Message", | |||||
ErrorUnexpectedResponse: "Unexpected Response", | |||||
} | } | ||||
func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) { | func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) { | ||||
@@ -82,7 +82,10 @@ func CompileFilter(filter string) (*ber.Packet, error) { | |||||
if err != nil { | if err != nil { | ||||
return nil, err | return nil, err | ||||
} | } | ||||
if pos != len(filter) { | |||||
switch { | |||||
case pos > len(filter): | |||||
return nil, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter")) | |||||
case pos < len(filter): | |||||
return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:]))) | return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:]))) | ||||
} | } | ||||
return packet, nil | return packet, nil | ||||
@@ -9,7 +9,7 @@ import ( | |||||
"io/ioutil" | "io/ioutil" | ||||
"os" | "os" | ||||
ber "gopkg.in/asn1-ber.v1" | |||||
"gopkg.in/asn1-ber.v1" | |||||
) | ) | ||||
// LDAP Application Codes | // LDAP Application Codes | ||||
@@ -153,16 +153,47 @@ func addLDAPDescriptions(packet *ber.Packet) (err error) { | |||||
func addControlDescriptions(packet *ber.Packet) { | func addControlDescriptions(packet *ber.Packet) { | ||||
packet.Description = "Controls" | packet.Description = "Controls" | ||||
for _, child := range packet.Children { | for _, child := range packet.Children { | ||||
var value *ber.Packet | |||||
controlType := "" | |||||
child.Description = "Control" | child.Description = "Control" | ||||
child.Children[0].Description = "Control Type (" + ControlTypeMap[child.Children[0].Value.(string)] + ")" | |||||
value := child.Children[1] | |||||
if len(child.Children) == 3 { | |||||
switch len(child.Children) { | |||||
case 0: | |||||
// at least one child is required for control type | |||||
continue | |||||
case 1: | |||||
// just type, no criticality or value | |||||
controlType = child.Children[0].Value.(string) | |||||
child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")" | |||||
case 2: | |||||
controlType = child.Children[0].Value.(string) | |||||
child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")" | |||||
// Children[1] could be criticality or value (both are optional) | |||||
// duck-type on whether this is a boolean | |||||
if _, ok := child.Children[1].Value.(bool); ok { | |||||
child.Children[1].Description = "Criticality" | |||||
} else { | |||||
child.Children[1].Description = "Control Value" | |||||
value = child.Children[1] | |||||
} | |||||
case 3: | |||||
// criticality and value present | |||||
controlType = child.Children[0].Value.(string) | |||||
child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")" | |||||
child.Children[1].Description = "Criticality" | child.Children[1].Description = "Criticality" | ||||
child.Children[2].Description = "Control Value" | |||||
value = child.Children[2] | value = child.Children[2] | ||||
} | |||||
value.Description = "Control Value" | |||||
switch child.Children[0].Value.(string) { | |||||
default: | |||||
// more than 3 children is invalid | |||||
continue | |||||
} | |||||
if value == nil { | |||||
continue | |||||
} | |||||
switch controlType { | |||||
case ControlTypePaging: | case ControlTypePaging: | ||||
value.Description += " (Paging)" | value.Description += " (Paging)" | ||||
if value.Value != nil { | if value.Value != nil { | ||||
@@ -188,18 +219,18 @@ func addControlDescriptions(packet *ber.Packet) { | |||||
for _, child := range sequence.Children { | for _, child := range sequence.Children { | ||||
if child.Tag == 0 { | if child.Tag == 0 { | ||||
//Warning | //Warning | ||||
child := child.Children[0] | |||||
packet := ber.DecodePacket(child.Data.Bytes()) | |||||
warningPacket := child.Children[0] | |||||
packet := ber.DecodePacket(warningPacket.Data.Bytes()) | |||||
val, ok := packet.Value.(int64) | val, ok := packet.Value.(int64) | ||||
if ok { | if ok { | ||||
if child.Tag == 0 { | |||||
if warningPacket.Tag == 0 { | |||||
//timeBeforeExpiration | //timeBeforeExpiration | ||||
value.Description += " (TimeBeforeExpiration)" | value.Description += " (TimeBeforeExpiration)" | ||||
child.Value = val | |||||
} else if child.Tag == 1 { | |||||
warningPacket.Value = val | |||||
} else if warningPacket.Tag == 1 { | |||||
//graceAuthNsRemaining | //graceAuthNsRemaining | ||||
value.Description += " (GraceAuthNsRemaining)" | value.Description += " (GraceAuthNsRemaining)" | ||||
child.Value = val | |||||
warningPacket.Value = val | |||||
} | } | ||||
} | } | ||||
} else if child.Tag == 1 { | } else if child.Tag == 1 { | ||||
@@ -135,10 +135,10 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa | |||||
extendedResponse := packet.Children[1] | extendedResponse := packet.Children[1] | ||||
for _, child := range extendedResponse.Children { | for _, child := range extendedResponse.Children { | ||||
if child.Tag == 11 { | if child.Tag == 11 { | ||||
passwordModifyReponseValue := ber.DecodePacket(child.Data.Bytes()) | |||||
if len(passwordModifyReponseValue.Children) == 1 { | |||||
if passwordModifyReponseValue.Children[0].Tag == 0 { | |||||
result.GeneratedPassword = ber.DecodeString(passwordModifyReponseValue.Children[0].Data.Bytes()) | |||||
passwordModifyResponseValue := ber.DecodePacket(child.Data.Bytes()) | |||||
if len(passwordModifyResponseValue.Children) == 1 { | |||||
if passwordModifyResponseValue.Children[0].Tag == 0 { | |||||
result.GeneratedPassword = ber.DecodeString(passwordModifyResponseValue.Children[0].Data.Bytes()) | |||||
} | } | ||||
} | } | ||||
} | } | ||||