diff --git a/join_group_request.go b/join_group_request.go index 3a7ba1712..97e9299ea 100644 --- a/join_group_request.go +++ b/join_group_request.go @@ -25,8 +25,10 @@ func (p *GroupProtocol) encode(pe packetEncoder) (err error) { } type JoinGroupRequest struct { + Version int16 GroupId string SessionTimeout int32 + RebalanceTimeout int32 MemberId string ProtocolType string GroupProtocols map[string][]byte // deprecated; use OrderedGroupProtocols @@ -38,6 +40,9 @@ func (r *JoinGroupRequest) encode(pe packetEncoder) error { return err } pe.putInt32(r.SessionTimeout) + if r.Version >= 1 { + pe.putInt32(r.RebalanceTimeout) + } if err := pe.putString(r.MemberId); err != nil { return err } @@ -76,6 +81,8 @@ func (r *JoinGroupRequest) encode(pe packetEncoder) error { } func (r *JoinGroupRequest) decode(pd packetDecoder, version int16) (err error) { + r.Version = version + if r.GroupId, err = pd.getString(); err != nil { return } @@ -84,6 +91,12 @@ func (r *JoinGroupRequest) decode(pd packetDecoder, version int16) (err error) { return } + if version >= 1 { + if r.RebalanceTimeout, err = pd.getInt32(); err != nil { + return err + } + } + if r.MemberId, err = pd.getString(); err != nil { return } @@ -118,11 +131,18 @@ func (r *JoinGroupRequest) key() int16 { } func (r *JoinGroupRequest) version() int16 { - return 0 + return r.Version } func (r *JoinGroupRequest) requiredVersion() KafkaVersion { - return V0_9_0_0 + switch r.Version { + case 2: + return V0_11_0_0 + case 1: + return V0_10_1_0 + default: + return V0_9_0_0 + } } func (r *JoinGroupRequest) AddGroupProtocol(name string, metadata []byte) { diff --git a/join_group_request_test.go b/join_group_request_test.go index 1ba3308bb..a2e17f980 100644 --- a/join_group_request_test.go +++ b/join_group_request_test.go @@ -3,7 +3,7 @@ package sarama import "testing" var ( - joinGroupRequestNoProtocols = []byte{ + joinGroupRequestV0_NoProtocols = []byte{ 0, 9, 'T', 'e', 's', 't', 'G', 'r', 'o', 'u', 'p', // Group ID 0, 0, 0, 100, // Session timeout 0, 0, // Member ID @@ -11,7 +11,7 @@ var ( 0, 0, 0, 0, // 0 protocol groups } - joinGroupRequestOneProtocol = []byte{ + joinGroupRequestV0_OneProtocol = []byte{ 0, 9, 'T', 'e', 's', 't', 'G', 'r', 'o', 'u', 'p', // Group ID 0, 0, 0, 100, // Session timeout 0, 11, 'O', 'n', 'e', 'P', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Member ID @@ -20,6 +20,17 @@ var ( 0, 3, 'o', 'n', 'e', // Protocol name 0, 0, 0, 3, 0x01, 0x02, 0x03, // protocol metadata } + + joinGroupRequestV1 = []byte{ + 0, 9, 'T', 'e', 's', 't', 'G', 'r', 'o', 'u', 'p', // Group ID + 0, 0, 0, 100, // Session timeout + 0, 0, 0, 200, // Rebalance timeout + 0, 11, 'O', 'n', 'e', 'P', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Member ID + 0, 8, 'c', 'o', 'n', 's', 'u', 'm', 'e', 'r', // Protocol Type + 0, 0, 0, 1, // 1 group protocol + 0, 3, 'o', 'n', 'e', // Protocol name + 0, 0, 0, 3, 0x01, 0x02, 0x03, // protocol metadata + } ) func TestJoinGroupRequest(t *testing.T) { @@ -27,20 +38,20 @@ func TestJoinGroupRequest(t *testing.T) { request.GroupId = "TestGroup" request.SessionTimeout = 100 request.ProtocolType = "consumer" - testRequest(t, "no protocols", request, joinGroupRequestNoProtocols) + testRequest(t, "V0: no protocols", request, joinGroupRequestV0_NoProtocols) } -func TestJoinGroupRequestOneProtocol(t *testing.T) { +func TestJoinGroupRequestV0_OneProtocol(t *testing.T) { request := new(JoinGroupRequest) request.GroupId = "TestGroup" request.SessionTimeout = 100 request.MemberId = "OneProtocol" request.ProtocolType = "consumer" request.AddGroupProtocol("one", []byte{0x01, 0x02, 0x03}) - packet := testRequestEncode(t, "one protocol", request, joinGroupRequestOneProtocol) + packet := testRequestEncode(t, "V0: one protocol", request, joinGroupRequestV0_OneProtocol) request.GroupProtocols = make(map[string][]byte) request.GroupProtocols["one"] = []byte{0x01, 0x02, 0x03} - testRequestDecode(t, "one protocol", request, packet) + testRequestDecode(t, "V0: one protocol", request, packet) } func TestJoinGroupRequestDeprecatedEncode(t *testing.T) { @@ -51,7 +62,22 @@ func TestJoinGroupRequestDeprecatedEncode(t *testing.T) { request.ProtocolType = "consumer" request.GroupProtocols = make(map[string][]byte) request.GroupProtocols["one"] = []byte{0x01, 0x02, 0x03} - packet := testRequestEncode(t, "one protocol", request, joinGroupRequestOneProtocol) + packet := testRequestEncode(t, "V0: one protocol", request, joinGroupRequestV0_OneProtocol) request.AddGroupProtocol("one", []byte{0x01, 0x02, 0x03}) - testRequestDecode(t, "one protocol", request, packet) + testRequestDecode(t, "V0: one protocol", request, packet) +} + +func TestJoinGroupRequestV1(t *testing.T) { + request := new(JoinGroupRequest) + request.Version = 1 + request.GroupId = "TestGroup" + request.SessionTimeout = 100 + request.RebalanceTimeout = 200 + request.MemberId = "OneProtocol" + request.ProtocolType = "consumer" + request.AddGroupProtocol("one", []byte{0x01, 0x02, 0x03}) + packet := testRequestEncode(t, "V1", request, joinGroupRequestV1) + request.GroupProtocols = make(map[string][]byte) + request.GroupProtocols["one"] = []byte{0x01, 0x02, 0x03} + testRequestDecode(t, "V1", request, packet) } diff --git a/join_group_response.go b/join_group_response.go index 6d35fe364..5752acc8a 100644 --- a/join_group_response.go +++ b/join_group_response.go @@ -1,6 +1,8 @@ package sarama type JoinGroupResponse struct { + Version int16 + ThrottleTime int32 Err KError GenerationId int32 GroupProtocol string @@ -22,6 +24,9 @@ func (r *JoinGroupResponse) GetMembers() (map[string]ConsumerGroupMemberMetadata } func (r *JoinGroupResponse) encode(pe packetEncoder) error { + if r.Version >= 2 { + pe.putInt32(r.ThrottleTime) + } pe.putInt16(int16(r.Err)) pe.putInt32(r.GenerationId) @@ -53,6 +58,14 @@ func (r *JoinGroupResponse) encode(pe packetEncoder) error { } func (r *JoinGroupResponse) decode(pd packetDecoder, version int16) (err error) { + r.Version = version + + if version >= 2 { + if r.ThrottleTime, err = pd.getInt32(); err != nil { + return + } + } + kerr, err := pd.getInt16() if err != nil { return err @@ -107,9 +120,16 @@ func (r *JoinGroupResponse) key() int16 { } func (r *JoinGroupResponse) version() int16 { - return 0 + return r.Version } func (r *JoinGroupResponse) requiredVersion() KafkaVersion { - return V0_9_0_0 + switch r.Version { + case 2: + return V0_11_0_0 + case 1: + return V0_10_1_0 + default: + return V0_9_0_0 + } } diff --git a/join_group_response_test.go b/join_group_response_test.go index ba7f71f20..a43b37a95 100644 --- a/join_group_response_test.go +++ b/join_group_response_test.go @@ -6,7 +6,7 @@ import ( ) var ( - joinGroupResponseNoError = []byte{ + joinGroupResponseV0_NoError = []byte{ 0x00, 0x00, // No error 0x00, 0x01, 0x02, 0x03, // Generation ID 0, 8, 'p', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Protocol name chosen @@ -15,7 +15,7 @@ var ( 0, 0, 0, 0, // No member info } - joinGroupResponseWithError = []byte{ + joinGroupResponseV0_WithError = []byte{ 0, 23, // Error: inconsistent group protocol 0x00, 0x00, 0x00, 0x00, // Generation ID 0, 0, // Protocol name chosen @@ -24,7 +24,7 @@ var ( 0, 0, 0, 0, // No member info } - joinGroupResponseLeader = []byte{ + joinGroupResponseV0_Leader = []byte{ 0x00, 0x00, // No error 0x00, 0x01, 0x02, 0x03, // Generation ID 0, 8, 'p', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Protocol name chosen @@ -34,13 +34,32 @@ var ( 0, 3, 'f', 'o', 'o', // Member ID 0, 0, 0, 3, 0x01, 0x02, 0x03, // Member metadata } + + joinGroupResponseV1 = []byte{ + 0x00, 0x00, // No error + 0x00, 0x01, 0x02, 0x03, // Generation ID + 0, 8, 'p', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Protocol name chosen + 0, 3, 'f', 'o', 'o', // Leader ID + 0, 3, 'b', 'a', 'r', // Member ID + 0, 0, 0, 0, // No member info + } + + joinGroupResponseV2 = []byte{ + 0, 0, 0, 100, + 0x00, 0x00, // No error + 0x00, 0x01, 0x02, 0x03, // Generation ID + 0, 8, 'p', 'r', 'o', 't', 'o', 'c', 'o', 'l', // Protocol name chosen + 0, 3, 'f', 'o', 'o', // Leader ID + 0, 3, 'b', 'a', 'r', // Member ID + 0, 0, 0, 0, // No member info + } ) -func TestJoinGroupResponse(t *testing.T) { +func TestJoinGroupResponseV0(t *testing.T) { var response *JoinGroupResponse response = new(JoinGroupResponse) - testVersionDecodable(t, "no error", response, joinGroupResponseNoError, 0) + testVersionDecodable(t, "no error", response, joinGroupResponseV0_NoError, 0) if response.Err != ErrNoError { t.Error("Decoding Err failed: no error expected but found", response.Err) } @@ -58,7 +77,7 @@ func TestJoinGroupResponse(t *testing.T) { } response = new(JoinGroupResponse) - testVersionDecodable(t, "with error", response, joinGroupResponseWithError, 0) + testVersionDecodable(t, "with error", response, joinGroupResponseV0_WithError, 0) if response.Err != ErrInconsistentGroupProtocol { t.Error("Decoding Err failed: ErrInconsistentGroupProtocol expected but found", response.Err) } @@ -76,7 +95,7 @@ func TestJoinGroupResponse(t *testing.T) { } response = new(JoinGroupResponse) - testVersionDecodable(t, "with error", response, joinGroupResponseLeader, 0) + testVersionDecodable(t, "with error", response, joinGroupResponseV0_Leader, 0) if response.Err != ErrNoError { t.Error("Decoding Err failed: ErrNoError expected but found", response.Err) } @@ -96,3 +115,58 @@ func TestJoinGroupResponse(t *testing.T) { t.Error("Decoding foo member failed, found:", response.Members["foo"]) } } + +func TestJoinGroupResponseV1(t *testing.T) { + response := new(JoinGroupResponse) + testVersionDecodable(t, "no error", response, joinGroupResponseV1, 1) + if response.Err != ErrNoError { + t.Error("Decoding Err failed: no error expected but found", response.Err) + } + if response.GenerationId != 66051 { + t.Error("Decoding GenerationId failed, found:", response.GenerationId) + } + if response.GroupProtocol != "protocol" { + t.Error("Decoding GroupProtocol failed, found:", response.GroupProtocol) + } + if response.LeaderId != "foo" { + t.Error("Decoding LeaderId failed, found:", response.LeaderId) + } + if response.MemberId != "bar" { + t.Error("Decoding MemberId failed, found:", response.MemberId) + } + if response.Version != 1 { + t.Error("Decoding Version failed, found:", response.Version) + } + if len(response.Members) != 0 { + t.Error("Decoding Members failed, found:", response.Members) + } +} + +func TestJoinGroupResponseV2(t *testing.T) { + response := new(JoinGroupResponse) + testVersionDecodable(t, "no error", response, joinGroupResponseV2, 2) + if response.ThrottleTime != 100 { + t.Error("Decoding ThrottleTime failed, found:", response.ThrottleTime) + } + if response.Err != ErrNoError { + t.Error("Decoding Err failed: no error expected but found", response.Err) + } + if response.GenerationId != 66051 { + t.Error("Decoding GenerationId failed, found:", response.GenerationId) + } + if response.GroupProtocol != "protocol" { + t.Error("Decoding GroupProtocol failed, found:", response.GroupProtocol) + } + if response.LeaderId != "foo" { + t.Error("Decoding LeaderId failed, found:", response.LeaderId) + } + if response.MemberId != "bar" { + t.Error("Decoding MemberId failed, found:", response.MemberId) + } + if response.Version != 2 { + t.Error("Decoding Version failed, found:", response.Version) + } + if len(response.Members) != 0 { + t.Error("Decoding Members failed, found:", response.Members) + } +}