Skip to content

Commit

Permalink
fix: PSK failing if config session cache set
Browse files Browse the repository at this point in the history
* Fix a bug causing PSK to fail if Config.ClientSessionCache is set.
* Removed `ClientSessionCacheOverride` from `UtlsPreSharedKeyExtension`. Set the `ClientSessionCache` in `Config`!

Co-Authored-By: zeeker999 <[email protected]>
  • Loading branch information
gaukas and zeeker999 committed Aug 17, 2023
1 parent 3d7eea3 commit 3162534
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 124 deletions.
8 changes: 3 additions & 5 deletions examples/tls-psk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,9 @@ func main() {
}

tlsConnPSK := tls.UClient(tcpConnPSK, &tls.Config{
ServerName: strings.Split(serverAddr, ":")[0],
// ClientSessionCache: csc, // set this will cause PSK to fail. This is a bug...
}, tls.HelloChrome_100_PSK, &tls.UtlsPreSharedKeyExtension{
ClientSessionCacheOverride: csc, // ONLY set your own ClientSessionCache here if you want to use PSK
})
ServerName: strings.Split(serverAddr, ":")[0],
ClientSessionCache: csc,
}, tls.HelloChrome_100_PSK, &tls.UtlsPreSharedKeyExtension{})

// HS
err = tlsConnPSK.Handshake()
Expand Down
10 changes: 5 additions & 5 deletions u_clienthello_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ func (c *CompressionMethodsJSONUnmarshaler) CompressionMethods() []uint8 {
}

type TLSExtensionsJSONUnmarshaler struct {
AllowUnknownExt bool // if set, unknown extensions will be added as GenericExtension, without recovering ext payload
ClientSessionCache ClientSessionCache // if set, PSK extension will be Unmarshaled into UtlsPreSharedKeyExtension. Otherwise FakePreSharedKeyExtension.
extensions []TLSExtensionJSON
AllowUnknownExt bool // if set, unknown extensions will be added as GenericExtension, without recovering ext payload
UseRealPSK bool // if set, PSK extension will be real PSK extension, otherwise it will be fake PSK extension
extensions []TLSExtensionJSON
}

func (e *TLSExtensionsJSONUnmarshaler) UnmarshalJSON(jsonStr []byte) error {
Expand Down Expand Up @@ -120,8 +120,8 @@ func (e *TLSExtensionsJSONUnmarshaler) UnmarshalJSON(jsonStr []byte) error {
switch extID {
case extensionPreSharedKey:
// PSK extension, need to see if we do real or fake PSK
if e.ClientSessionCache != nil {
ext = &UtlsPreSharedKeyExtension{ClientSessionCacheOverride: e.ClientSessionCache}
if e.UseRealPSK {
ext = &UtlsPreSharedKeyExtension{}
} else {
ext = &FakePreSharedKeyExtension{}
}
Expand Down
100 changes: 12 additions & 88 deletions u_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (chs *ClientHelloSpec) ReadCompressionMethods(compressionMethods []byte) er
// a byte slice into []TLSExtension.
//
// If keepPSK is not set, the PSK extension will cause an error.
func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool, clientSessionCache ...ClientSessionCache) error {
func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool, realPSK bool) error {
extensions := cryptobyte.String(b)
for !extensions.Empty() {
var extension uint16
Expand All @@ -228,8 +228,8 @@ func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool,
switch extension {
case extensionPreSharedKey:
// PSK extension, need to see if we do real or fake PSK
if len(clientSessionCache) > 0 && clientSessionCache[0] != nil {
extWriter = &UtlsPreSharedKeyExtension{ClientSessionCacheOverride: clientSessionCache[0]}
if realPSK {
extWriter = &UtlsPreSharedKeyExtension{}
} else {
extWriter = &FakePreSharedKeyExtension{}
}
Expand Down Expand Up @@ -464,96 +464,20 @@ func (chs *ClientHelloSpec) ImportTLSClientHelloFromJSON(jsonB []byte) error {
}

// FromRaw converts a ClientHello message in the form of raw bytes into a ClientHelloSpec.
func (chs *ClientHelloSpec) FromRaw(raw []byte, allowBluntMimicry ...bool) error {
//
// ctrlFlags: []bool{bluntMimicry, realPSK}
func (chs *ClientHelloSpec) FromRaw(raw []byte, ctrlFlags ...bool) error {
if chs == nil {
return errors.New("cannot unmarshal into nil ClientHelloSpec")
}

var bluntMimicry = false
if len(allowBluntMimicry) == 1 {
bluntMimicry = allowBluntMimicry[0]
}

*chs = ClientHelloSpec{} // reset
s := cryptobyte.String(raw)

var contentType uint8
var recordVersion uint16
if !s.ReadUint8(&contentType) || // record type
!s.ReadUint16(&recordVersion) || !s.Skip(2) { // record version and length
return errors.New("unable to read record type, version, and length")
}

if recordType(contentType) != recordTypeHandshake {
return errors.New("record is not a handshake")
}

var handshakeVersion uint16
var handshakeType uint8

if !s.ReadUint8(&handshakeType) || !s.Skip(3) || // message type and 3 byte length
!s.ReadUint16(&handshakeVersion) || !s.Skip(32) { // 32 byte random
return errors.New("unable to read handshake message type, length, and random")
}

if handshakeType != typeClientHello {
return errors.New("handshake message is not a ClientHello")
}

chs.TLSVersMin = recordVersion
chs.TLSVersMax = handshakeVersion

var ignoredSessionID cryptobyte.String
if !s.ReadUint8LengthPrefixed(&ignoredSessionID) {
return errors.New("unable to read session id")
}

// CipherSuites
var cipherSuitesBytes cryptobyte.String
if !s.ReadUint16LengthPrefixed(&cipherSuitesBytes) {
return errors.New("unable to read ciphersuites")
}

if err := chs.ReadCipherSuites(cipherSuitesBytes); err != nil {
return err
}

// CompressionMethods
var compressionMethods cryptobyte.String
if !s.ReadUint8LengthPrefixed(&compressionMethods) {
return errors.New("unable to read compression methods")
}

if err := chs.ReadCompressionMethods(compressionMethods); err != nil {
return err
var realPSK = false
if len(ctrlFlags) > 0 {
bluntMimicry = ctrlFlags[0]
}

if s.Empty() {
// Extensions are optional
return nil
}

var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) {
return errors.New("unable to read extensions data")
}

if err := chs.ReadTLSExtensions(extensions, bluntMimicry); err != nil {
return err
}

return nil
}

// FromRaw converts a ClientHello message in the form of raw bytes into a ClientHelloSpec.
func (chs *ClientHelloSpec) FromRawWithClientSessionCache(raw []byte, csc ClientSessionCache, allowBluntMimicry ...bool) error {
if chs == nil {
return errors.New("cannot unmarshal into nil ClientHelloSpec")
}

var bluntMimicry = false
if len(allowBluntMimicry) == 1 {
bluntMimicry = allowBluntMimicry[0]
if len(ctrlFlags) > 1 {
realPSK = ctrlFlags[1]
}

*chs = ClientHelloSpec{} // reset
Expand Down Expand Up @@ -620,7 +544,7 @@ func (chs *ClientHelloSpec) FromRawWithClientSessionCache(raw []byte, csc Client
return errors.New("unable to read extensions data")
}

if err := chs.ReadTLSExtensions(extensions, bluntMimicry, csc); err != nil {
if err := chs.ReadTLSExtensions(extensions, bluntMimicry, realPSK); err != nil {
return err
}

Expand Down
8 changes: 2 additions & 6 deletions u_fingerprinter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type Fingerprinter struct {
// (including things like different SNI lengths) would cause padding to be necessary
AlwaysAddPadding bool

ClientSessionCache ClientSessionCache // if set, PSK extension will be made into UtlsPreSharedKeyExtension. Otherwise FakePreSharedKeyExtension.
RealPSKResumption bool // if set, PSK extension (if any) will be real PSK extension, otherwise it will be fake PSK extension
}

// FingerprintClientHello returns a ClientHelloSpec which is based on the
Expand All @@ -46,11 +46,7 @@ func (f *Fingerprinter) FingerprintClientHello(data []byte) (clientHelloSpec *Cl
func (f *Fingerprinter) RawClientHello(raw []byte) (clientHelloSpec *ClientHelloSpec, err error) {
clientHelloSpec = &ClientHelloSpec{}

if f.ClientSessionCache != nil {
err = clientHelloSpec.FromRawWithClientSessionCache(raw, f.ClientSessionCache, f.AllowBluntMimicry)
} else {
err = clientHelloSpec.FromRaw(raw, f.AllowBluntMimicry)
}
err = clientHelloSpec.FromRaw(raw, f.AllowBluntMimicry, f.RealPSKResumption)
if err != nil {
return nil, err
}
Expand Down
5 changes: 4 additions & 1 deletion u_parrots.go
Original file line number Diff line number Diff line change
Expand Up @@ -2453,7 +2453,10 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
if cs != nil {
session = cs.session
}
// TLS 1.3 (PSK) resumption is handled by PreSharedKeyExtension in MarshalClientHello()
}
// TLS 1.3 (PSK) resumption is handled by PreSharedKeyExtension in MarshalClientHello()
if session != nil && session.version == VersionTLS13 {
break
}
err := uconn.SetSessionState(cs)
if err != nil {
Expand Down
19 changes: 0 additions & 19 deletions u_pre_shared_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,6 @@ func (*UnimplementedPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int,
type UtlsPreSharedKeyExtension struct {
UnimplementedPreSharedKeyExtension

// ClientSessionCacheOverride is used to specify the ClientSessionCache to be used
// for PSK-resumption.
//
// bug: tls.Config.ClientSessionCache must be nil for PSK-resumption to work, even though
// it is supposed to be overridden by ClientSessionCacheOverride.
ClientSessionCacheOverride ClientSessionCache

identities []pskIdentity
binders [][]byte
binderKey []byte // this will be used to compute the binder when hello message is ready
Expand Down Expand Up @@ -172,12 +165,6 @@ func (e *UtlsPreSharedKeyExtension) ReadWithRawHello(raw, b []byte) (int, error)
}

func (e *UtlsPreSharedKeyExtension) preloadSession(uc *UConn) error {
// var sessionCache ClientSessionCache
// must set either e.Session or uc.config.ClientSessionCache
if e.ClientSessionCacheOverride != nil {
uc.config.ClientSessionCache = e.ClientSessionCacheOverride
}

// load Hello
hello := uc.HandshakeState.Hello.getPrivatePtr()
// try to use loadSession()
Expand All @@ -203,16 +190,10 @@ func (e *UtlsPreSharedKeyExtension) preloadSession(uc *UConn) error {
}

func (e *UtlsPreSharedKeyExtension) Write(b []byte) (int, error) {
if e.ClientSessionCacheOverride == nil {
return 0, errors.New("tls: ClientSessionCache must be set to use UtlsPreSharedKeyExtension")
}
return len(b), nil // ignore the data
}

func (e *UtlsPreSharedKeyExtension) UnmarshalJSON(_ []byte) error {
if e.ClientSessionCacheOverride == nil {
return errors.New("tls: ClientSessionCache must be set to use UtlsPreSharedKeyExtension")
}
return nil // ignore the data
}

Expand Down

0 comments on commit 3162534

Please sign in to comment.