@@ -32,7 +32,6 @@ import (
3232 "google.golang.org/grpc"
3333 "google.golang.org/grpc/codes"
3434 "google.golang.org/grpc/credentials"
35- "google.golang.org/grpc/internal/envconfig"
3635 "google.golang.org/grpc/internal/grpctest"
3736 "google.golang.org/grpc/internal/stubserver"
3837 "google.golang.org/grpc/status"
@@ -411,12 +410,6 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
411410// correctly for a server that doesn't specify the NextProtos field and uses
412411// GetConfigForClient to provide the TLS config during the handshake.
413412func (s ) TestTLS_ServerConfiguresALPNByDefault (t * testing.T ) {
414- initialVal := envconfig .EnforceALPNEnabled
415- defer func () {
416- envconfig .EnforceALPNEnabled = initialVal
417- }()
418- envconfig .EnforceALPNEnabled = true
419-
420413 ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
421414 defer cancel ()
422415
@@ -453,156 +446,104 @@ func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
453446// TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
454447// connecting to a server that doesn't support ALPN.
455448func (s ) TestTLS_DisabledALPNClient (t * testing.T ) {
456- initialVal := envconfig .EnforceALPNEnabled
457- defer func () {
458- envconfig .EnforceALPNEnabled = initialVal
459- }()
460-
461- tests := []struct {
462- name string
463- alpnEnforced bool
464- wantErr bool
465- }{
466- {
467- name : "enforced" ,
468- alpnEnforced : true ,
469- wantErr : true ,
470- },
471- {
472- name : "not_enforced" ,
473- },
449+ listener , err := tls .Listen ("tcp" , "localhost:0" , & tls.Config {
450+ Certificates : []tls.Certificate {serverCert },
451+ NextProtos : []string {}, // Empty list indicates ALPN is disabled.
452+ })
453+ if err != nil {
454+ t .Fatalf ("Error starting TLS server: %v" , err )
474455 }
475456
476- for _ , tc := range tests {
477- t .Run (tc .name , func (t * testing.T ) {
478- envconfig .EnforceALPNEnabled = tc .alpnEnforced
479-
480- listener , err := tls .Listen ("tcp" , "localhost:0" , & tls.Config {
481- Certificates : []tls.Certificate {serverCert },
482- NextProtos : []string {}, // Empty list indicates ALPN is disabled.
483- })
484- if err != nil {
485- t .Fatalf ("Error starting TLS server: %v" , err )
486- }
487-
488- errCh := make (chan error , 1 )
489- go func () {
490- conn , err := listener .Accept ()
491- if err != nil {
492- errCh <- fmt .Errorf ("listener.Accept returned error: %v" , err )
493- } else {
494- // The first write to the TLS listener initiates the TLS handshake.
495- conn .Write ([]byte ("Hello, World!" ))
496- conn .Close ()
497- }
498- close (errCh )
499- }()
457+ errCh := make (chan error , 1 )
458+ go func () {
459+ conn , err := listener .Accept ()
460+ if err != nil {
461+ errCh <- fmt .Errorf ("listener.Accept returned error: %v" , err )
462+ } else {
463+ // The first write to the TLS listener initiates the TLS handshake.
464+ conn .Write ([]byte ("Hello, World!" ))
465+ conn .Close ()
466+ }
467+ close (errCh )
468+ }()
500469
501- serverAddr := listener .Addr ().String ()
502- conn , err := net .Dial ("tcp" , serverAddr )
503- if err != nil {
504- t .Fatalf ("net.Dial(%s) failed: %v" , serverAddr , err )
505- }
506- defer conn .Close ()
470+ serverAddr := listener .Addr ().String ()
471+ conn , err := net .Dial ("tcp" , serverAddr )
472+ if err != nil {
473+ t .Fatalf ("net.Dial(%s) failed: %v" , serverAddr , err )
474+ }
475+ defer conn .Close ()
507476
508- ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
509- defer cancel ()
477+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
478+ defer cancel ()
510479
511- clientCfg := tls.Config {
512- ServerName : serverName ,
513- RootCAs : certPool ,
514- NextProtos : []string {"h2" },
515- }
516- _ , _ , err = credentials .NewTLS (& clientCfg ).ClientHandshake (ctx , serverName , conn )
480+ clientCfg := tls.Config {
481+ ServerName : serverName ,
482+ RootCAs : certPool ,
483+ NextProtos : []string {"h2" },
484+ }
485+ _ , _ , err = credentials .NewTLS (& clientCfg ).ClientHandshake (ctx , serverName , conn )
517486
518- if gotErr := (err != nil ); gotErr != tc . wantErr {
519- t .Errorf ("ClientHandshake returned unexpected error: got=%v, want=%t" , err , tc . wantErr )
520- }
487+ if gotErr , wantErr := (err != nil ), true ; gotErr != wantErr {
488+ t .Errorf ("ClientHandshake returned unexpected error: got=%v, want=%t" , err , wantErr )
489+ }
521490
522- select {
523- case err := <- errCh :
524- if err != nil {
525- t .Fatalf ("Unexpected error received from server: %v" , err )
526- }
527- case <- ctx .Done ():
528- t .Fatalf ("Timeout waiting for error from server" )
529- }
530- })
491+ select {
492+ case err := <- errCh :
493+ if err != nil {
494+ t .Fatalf ("Unexpected error received from server: %v" , err )
495+ }
496+ case <- ctx .Done ():
497+ t .Fatalf ("Timeout waiting for error from server" )
531498 }
532499}
533500
534501// TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
535502// accepting a request from a client that doesn't support ALPN.
536503func (s ) TestTLS_DisabledALPNServer (t * testing.T ) {
537- initialVal := envconfig .EnforceALPNEnabled
538- defer func () {
539- envconfig .EnforceALPNEnabled = initialVal
540- }()
541-
542- tests := []struct {
543- name string
544- alpnEnforced bool
545- wantErr bool
546- }{
547- {
548- name : "enforced" ,
549- alpnEnforced : true ,
550- wantErr : true ,
551- },
552- {
553- name : "not_enforced" ,
554- },
504+ listener , err := net .Listen ("tcp" , "localhost:0" )
505+ if err != nil {
506+ t .Fatalf ("Error starting server: %v" , err )
555507 }
556508
557- for _ , tc := range tests {
558- t .Run (tc .name , func (t * testing.T ) {
559- envconfig .EnforceALPNEnabled = tc .alpnEnforced
560-
561- listener , err := net .Listen ("tcp" , "localhost:0" )
562- if err != nil {
563- t .Fatalf ("Error starting server: %v" , err )
564- }
565-
566- errCh := make (chan error , 1 )
567- go func () {
568- conn , err := listener .Accept ()
569- if err != nil {
570- errCh <- fmt .Errorf ("listener.Accept returned error: %v" , err )
571- return
572- }
573- defer conn .Close ()
574- serverCfg := tls.Config {
575- Certificates : []tls.Certificate {serverCert },
576- NextProtos : []string {"h2" },
577- }
578- _ , _ , err = credentials .NewTLS (& serverCfg ).ServerHandshake (conn )
579- if gotErr := (err != nil ); gotErr != tc .wantErr {
580- t .Errorf ("ServerHandshake returned unexpected error: got=%v, want=%t" , err , tc .wantErr )
581- }
582- close (errCh )
583- }()
509+ errCh := make (chan error , 1 )
510+ go func () {
511+ conn , err := listener .Accept ()
512+ if err != nil {
513+ errCh <- fmt .Errorf ("listener.Accept returned error: %v" , err )
514+ return
515+ }
516+ defer conn .Close ()
517+ serverCfg := tls.Config {
518+ Certificates : []tls.Certificate {serverCert },
519+ NextProtos : []string {"h2" },
520+ }
521+ _ , _ , err = credentials .NewTLS (& serverCfg ).ServerHandshake (conn )
522+ if gotErr , wantErr := (err != nil ), true ; gotErr != wantErr {
523+ t .Errorf ("ServerHandshake returned unexpected error: got=%v, want=%t" , err , wantErr )
524+ }
525+ close (errCh )
526+ }()
584527
585- serverAddr := listener .Addr ().String ()
586- clientCfg := & tls.Config {
587- Certificates : []tls.Certificate {serverCert },
588- NextProtos : []string {}, // Empty list indicates ALPN is disabled.
589- RootCAs : certPool ,
590- ServerName : serverName ,
591- }
592- conn , err := tls .Dial ("tcp" , serverAddr , clientCfg )
593- if err != nil {
594- t .Fatalf ("tls.Dial(%s) failed: %v" , serverAddr , err )
595- }
596- defer conn .Close ()
597-
598- select {
599- case <- time .After (defaultTestTimeout ):
600- t .Fatal ("Timed out waiting for completion" )
601- case err := <- errCh :
602- if err != nil {
603- t .Fatalf ("Unexpected server error: %v" , err )
604- }
605- }
606- })
528+ serverAddr := listener .Addr ().String ()
529+ clientCfg := & tls.Config {
530+ Certificates : []tls.Certificate {serverCert },
531+ NextProtos : []string {}, // Empty list indicates ALPN is disabled.
532+ RootCAs : certPool ,
533+ ServerName : serverName ,
534+ }
535+ conn , err := tls .Dial ("tcp" , serverAddr , clientCfg )
536+ if err != nil {
537+ t .Fatalf ("tls.Dial(%s) failed: %v" , serverAddr , err )
538+ }
539+ defer conn .Close ()
540+
541+ select {
542+ case <- time .After (defaultTestTimeout ):
543+ t .Fatal ("Timed out waiting for completion" )
544+ case err := <- errCh :
545+ if err != nil {
546+ t .Fatalf ("Unexpected server error: %v" , err )
547+ }
607548 }
608549}
0 commit comments