Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: move tls to a separate file #422

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions sztp-agent/pkg/secureagent/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
SPDX-License-Identifier: Apache-2.0
Copyright (C) 2022-2023 Intel Corporation
Copyright (c) 2022 Dell Inc, or its subsidiaries.
Copyright (C) 2022 Red Hat.
*/

// Package secureagent implements the secure agent
package secureagent

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"io"
"log"
"net/http"
"os"
"strconv"
"strings"
)

func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapServerPostOutput, error) {
var postResponse BootstrapServerPostOutput
var errorResponse BootstrapServerErrorOutput

log.Println("[DEBUG] Sending to: " + url)
log.Println("[DEBUG] Sending input: " + input)

body := strings.NewReader(input)
r, err := http.NewRequest(http.MethodPost, url, body)
if err != nil {
return nil, err
}

r.SetBasicAuth(a.GetSerialNumber(), a.GetDevicePassword())
r.Header.Add("Content-Type", a.GetContentTypeReq())

caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert())
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey())

client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{ //nolint:gosec
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}
res, err := client.Do(r)
if err != nil {
log.Println("Error doing the request", err.Error())
return nil, err
}
defer func() {
if err := res.Body.Close(); err != nil {
log.Println("Error when closing:", err)
}
}()

bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
log.Println("Error reading the request", err.Error())
return nil, err
}

decoder := json.NewDecoder(bytes.NewReader(bodyBytes))
decoder.DisallowUnknownFields()
if !empty {
derr := decoder.Decode(&postResponse)
if derr != nil {
errdecoder := json.NewDecoder(bytes.NewReader(bodyBytes))
errdecoder.DisallowUnknownFields()
eerr := errdecoder.Decode(&errorResponse)
if eerr != nil {
log.Println("Received unknown response", string(bodyBytes))
return nil, derr
}
return nil, errors.New("[ERROR] Expected conveyed-information" +
", received error type=" + errorResponse.IetfRestconfErrors.Error[0].ErrorType +
", tag=" + errorResponse.IetfRestconfErrors.Error[0].ErrorTag +
", message=" + errorResponse.IetfRestconfErrors.Error[0].ErrorMessage)
}
log.Println(postResponse)
}
if res.StatusCode != http.StatusOK {
return nil, errors.New("[ERROR] Status code received: " + strconv.Itoa(res.StatusCode) + " ...but status code expected: " + strconv.Itoa(http.StatusOK))
}
return &postResponse, nil
}
53 changes: 53 additions & 0 deletions sztp-agent/pkg/secureagent/tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2022-2023 Red Hat.

// Package secureagent implements the secure agent
package secureagent

import (
"reflect"
"testing"
)

func TestAgent_doTLSRequest(t *testing.T) {
type fields struct {
BootstrapURL string
SerialNumber string
DevicePassword string
DevicePrivateKey string
DeviceEndEntityCert string
BootstrapTrustAnchorCert string
ContentTypeReq string
InputJSONContent string
DhcpLeaseFile string
}
var tests []struct {
name string
fields fields
want *BootstrapServerPostOutput
wantErr bool
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Agent{
BootstrapURL: tt.fields.BootstrapURL,
SerialNumber: tt.fields.SerialNumber,
DevicePassword: tt.fields.DevicePassword,
DevicePrivateKey: tt.fields.DevicePrivateKey,
DeviceEndEntityCert: tt.fields.DeviceEndEntityCert,
BootstrapTrustAnchorCert: tt.fields.BootstrapTrustAnchorCert,
ContentTypeReq: tt.fields.ContentTypeReq,
InputJSONContent: tt.fields.InputJSONContent,
DhcpLeaseFile: tt.fields.DhcpLeaseFile,
}
got, err := a.doTLSRequest(a.GetInputJSONContent(), a.GetBootstrapURL(), false)
if (err != nil) != tt.wantErr {
t.Errorf("doTLSRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("doTLSRequest() got = %v, want %v", got, tt.want)
}
})
}
}
78 changes: 0 additions & 78 deletions sztp-agent/pkg/secureagent/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,10 @@ package secureagent

import (
"bufio"
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"io"
"log"
"net/http"
"os"
"regexp"
"strconv"
"strings"

"github.com/jaypipes/ghw"
Expand Down Expand Up @@ -49,77 +42,6 @@ func extractfromLine(line, regex string, index int) string {
return re.FindAllString(line, -1)[index]
}

func (a *Agent) doTLSRequest(input string, url string, empty bool) (*BootstrapServerPostOutput, error) {
var postResponse BootstrapServerPostOutput
var errorResponse BootstrapServerErrorOutput

log.Println("[DEBUG] Sending to: " + url)
log.Println("[DEBUG] Sending input: " + input)

body := strings.NewReader(input)
r, err := http.NewRequest(http.MethodPost, url, body)
if err != nil {
return nil, err
}

r.SetBasicAuth(a.GetSerialNumber(), a.GetDevicePassword())
r.Header.Add("Content-Type", a.GetContentTypeReq())

caCert, _ := os.ReadFile(a.GetBootstrapTrustAnchorCert())
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
cert, _ := tls.LoadX509KeyPair(a.GetDeviceEndEntityCert(), a.GetDevicePrivateKey())

client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{ //nolint:gosec
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
},
},
}
res, err := client.Do(r)
if err != nil {
log.Println("Error doing the request", err.Error())
return nil, err
}
defer func() {
if err := res.Body.Close(); err != nil {
log.Println("Error when closing:", err)
}
}()

bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
log.Println("Error reading the request", err.Error())
return nil, err
}

decoder := json.NewDecoder(bytes.NewReader(bodyBytes))
decoder.DisallowUnknownFields()
if !empty {
derr := decoder.Decode(&postResponse)
if derr != nil {
errdecoder := json.NewDecoder(bytes.NewReader(bodyBytes))
errdecoder.DisallowUnknownFields()
eerr := errdecoder.Decode(&errorResponse)
if eerr != nil {
log.Println("Received unknown response", string(bodyBytes))
return nil, derr
}
return nil, errors.New("[ERROR] Expected conveyed-information" +
", received error type=" + errorResponse.IetfRestconfErrors.Error[0].ErrorType +
", tag=" + errorResponse.IetfRestconfErrors.Error[0].ErrorTag +
", message=" + errorResponse.IetfRestconfErrors.Error[0].ErrorMessage)
}
log.Println(postResponse)
}
if res.StatusCode != http.StatusOK {
return nil, errors.New("[ERROR] Status code received: " + strconv.Itoa(res.StatusCode) + " ...but status code expected: " + strconv.Itoa(http.StatusOK))
}
return &postResponse, nil
}

// GetSerialNumber returns the serial number of the device
func GetSerialNumber(givenSerialNumber string) string {
if givenSerialNumber != "" {
Expand Down
44 changes: 0 additions & 44 deletions sztp-agent/pkg/secureagent/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,54 +5,10 @@
package secureagent

import (
"reflect"
"strings"
"testing"
)

func TestAgent_doTLSRequest(t *testing.T) {
type fields struct {
BootstrapURL string
SerialNumber string
DevicePassword string
DevicePrivateKey string
DeviceEndEntityCert string
BootstrapTrustAnchorCert string
ContentTypeReq string
InputJSONContent string
DhcpLeaseFile string
}
var tests []struct {
name string
fields fields
want *BootstrapServerPostOutput
wantErr bool
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Agent{
BootstrapURL: tt.fields.BootstrapURL,
SerialNumber: tt.fields.SerialNumber,
DevicePassword: tt.fields.DevicePassword,
DevicePrivateKey: tt.fields.DevicePrivateKey,
DeviceEndEntityCert: tt.fields.DeviceEndEntityCert,
BootstrapTrustAnchorCert: tt.fields.BootstrapTrustAnchorCert,
ContentTypeReq: tt.fields.ContentTypeReq,
InputJSONContent: tt.fields.InputJSONContent,
DhcpLeaseFile: tt.fields.DhcpLeaseFile,
}
got, err := a.doTLSRequest(a.GetInputJSONContent(), a.GetBootstrapURL(), false)
if (err != nil) != tt.wantErr {
t.Errorf("doTLSRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("doTLSRequest() got = %v, want %v", got, tt.want)
}
})
}
}

func Test_extractfromLine(t *testing.T) {
type args struct {
line string
Expand Down
Loading