diff --git a/encode.go b/encode.go index 8487ea0..fcf712a 100644 --- a/encode.go +++ b/encode.go @@ -7,7 +7,7 @@ import ( ) var ( - soapPrefix = "soap" + soapPrefix = "soap" customEnvelopeAttrs map[string]string = nil ) @@ -40,36 +40,86 @@ func (c process) MarshalXML(e *xml.Encoder, _ xml.StartElement) error { tokens.startEnvelope() if c.Client.HeaderParams != nil { tokens.startHeader(c.Client.HeaderName, namespace) - tokens.recursiveEncode(c.Client.HeaderParams) + if err := tokens.recursiveEncode(c.Client.HeaderParams); err != nil { + return err + } tokens.endHeader(c.Client.HeaderName) } - err := tokens.startBody(c.Request.Method, namespace) + err := tokens.startSoapBody(c.Request.Method, namespace) if err != nil { return err } - tokens.recursiveEncode(c.Request.Params) + err = tokens.bodyContents(c, namespace, e) + if err != nil { + return err + } //end envelope - tokens.endBody(c.Request.Method) + tokens.endSoapBody() tokens.endEnvelope() + if err := tokens.flush(e); err != nil { + return err + } + + return e.Flush() +} + +func (tokens *tokenData) bodyContents(c process, namespace string, e *xml.Encoder) error { + isStruct := false + t := reflect.TypeOf(c.Request.Params) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + isStruct = true + } + if isStruct { + // Just use encoding/xml directly for structs, which allows for much more + // sophisticated control over the XML structure. The top-level element + // is intrinsically part of the encoding, so we don't need to do that + // for ourselves + + // Flush any pending tokens before we send things directly to the encoder + if err := tokens.flush(e); err != nil { + return err + } + if err := e.Encode(c.Request.Params); err != nil { + return err + } + // if err := tokens.flush(e); err != nil { + // return err + // } + } else { + // For non-structs, we have to explicitly wrap a top-level element around + // the actual data + tokens.startBodyContents(c.Request.Method, namespace) + if err := tokens.recursiveEncode(c.Request.Params); err != nil { + return err + } + tokens.endBodyContents(c.Request.Method) + } + return nil +} + +func (tokens *tokenData) flush(e *xml.Encoder) error { for _, t := range tokens.data { err := e.EncodeToken(t) if err != nil { return err } } - - return e.Flush() + tokens.data = []xml.Token{} + return nil } type tokenData struct { data []xml.Token } -func (tokens *tokenData) recursiveEncode(hm interface{}) { +func (tokens *tokenData) recursiveEncode(hm interface{}) error { v := reflect.ValueOf(hm) switch v.Kind() { @@ -83,12 +133,16 @@ func (tokens *tokenData) recursiveEncode(hm interface{}) { } tokens.data = append(tokens.data, t) - tokens.recursiveEncode(v.MapIndex(key).Interface()) + if err := tokens.recursiveEncode(v.MapIndex(key).Interface()); err != nil { + return err + } tokens.data = append(tokens.data, xml.EndElement{Name: t.Name}) } case reflect.Slice: for i := 0; i < v.Len(); i++ { - tokens.recursiveEncode(v.Index(i).Interface()) + if err := tokens.recursiveEncode(v.Index(i).Interface()); err != nil { + return err + } } case reflect.Array: if v.Len() == 2 { @@ -101,7 +155,9 @@ func (tokens *tokenData) recursiveEncode(hm interface{}) { } tokens.data = append(tokens.data, t) - tokens.recursiveEncode(v.Index(1).Interface()) + if err := tokens.recursiveEncode(v.Index(1).Interface()); err != nil { + return err + } tokens.data = append(tokens.data, xml.EndElement{Name: t.Name}) } case reflect.String: @@ -110,6 +166,7 @@ func (tokens *tokenData) recursiveEncode(hm interface{}) { case reflect.Struct: tokens.data = append(tokens.data, v.Interface()) } + return nil } func (tokens *tokenData) startEnvelope() { @@ -130,7 +187,7 @@ func (tokens *tokenData) startEnvelope() { e.Attr = make([]xml.Attr, 0) for local, value := range customEnvelopeAttrs { e.Attr = append(e.Attr, xml.Attr{ - Name: xml.Name{Space: "", Local: local}, + Name: xml.Name{Space: "", Local: local}, Value: value, }) } @@ -174,8 +231,6 @@ func (tokens *tokenData) startHeader(m, n string) { } tokens.data = append(tokens.data, h, r) - - return } func (tokens *tokenData) endHeader(m string) { @@ -201,18 +256,22 @@ func (tokens *tokenData) endHeader(m string) { tokens.data = append(tokens.data, r, h) } -func (tokens *tokenData) startBody(m, n string) error { +func (tokens *tokenData) startSoapBody(m, n string) error { + if m == "" || n == "" { + return fmt.Errorf("method or namespace is empty") + } + b := xml.StartElement{ Name: xml.Name{ Space: "", Local: fmt.Sprintf("%s:Body", soapPrefix), }, } + tokens.data = append(tokens.data, b) + return nil +} - if m == "" || n == "" { - return fmt.Errorf("method or namespace is empty") - } - +func (tokens *tokenData) startBodyContents(m, n string) { r := xml.StartElement{ Name: xml.Name{ Space: "", @@ -222,27 +281,26 @@ func (tokens *tokenData) startBody(m, n string) error { {Name: xml.Name{Space: "", Local: "xmlns"}, Value: n}, }, } - - tokens.data = append(tokens.data, b, r) - - return nil + tokens.data = append(tokens.data, r) } // endToken close body of the envelope -func (tokens *tokenData) endBody(m string) { - b := xml.EndElement{ +func (tokens *tokenData) endBodyContents(m string) { + r := xml.EndElement{ Name: xml.Name{ Space: "", - Local: fmt.Sprintf("%s:Body", soapPrefix), + Local: m, }, } + tokens.data = append(tokens.data, r) +} - r := xml.EndElement{ +func (tokens *tokenData) endSoapBody() { + b := xml.EndElement{ Name: xml.Name{ Space: "", - Local: m, + Local: fmt.Sprintf("%s:Body", soapPrefix), }, } - - tokens.data = append(tokens.data, r, b) + tokens.data = append(tokens.data, b) } diff --git a/encode_test.go b/encode_test.go index a370e1c..f775f33 100644 --- a/encode_test.go +++ b/encode_test.go @@ -96,7 +96,7 @@ func TestClient_MarshalXML4(t *testing.T) { func TestSetCustomEnvelope(t *testing.T) { SetCustomEnvelope("soapenv", map[string]string{ "xmlns:soapenv": "http://schemas.xmlsoap.org/soap/envelope/", - "xmlns:tem": "http://tempuri.org/", + "xmlns:tem": "http://tempuri.org/", }) soap, err := SoapClient("http://ec.europa.eu/taxation_customs/vies/checkVatService.wsdl", nil) @@ -111,3 +111,88 @@ func TestSetCustomEnvelope(t *testing.T) { } } } + +type checkVatApprox struct { + XMLName xml.Name `xml:"urn:ec.europa.eu:taxud:vies:services:checkVat:types checkVatApprox"` + CountryCode string `xml:"countryCode,omitempty"` + VatNumber string `xml:"vatNumber"` + TraderName string `xml:"traderName,omitempty"` +} +type checkVatApproxResponse struct { + CountryCode string `xml:"countryCode"` + VatNumber string `xml:"vatNumber"` + Valid bool `xml:"valid"` + TraderName string `xml:"traderName,omitempty"` +} + +func (cva *checkVatApprox) SoapBuildRequest() *Request { + r := NewRequest("checkVatApprox", cva) + return r +} + +var encoderParamsTests = []struct { + Desc string + WSDL string + Params *checkVatApprox + Response *checkVatApproxResponse + Err string +}{ + { + Desc: "Fetch a non-existent VAT number", + WSDL: "http://ec.europa.eu/taxation_customs/vies/checkVatService.wsdl", + Params: &checkVatApprox{CountryCode: "fr", VatNumber: "invalid"}, + Response: &checkVatApproxResponse{ + CountryCode: "FR", VatNumber: "invalid", Valid: false, TraderName: "---", + }, + }, + { + Desc: "Fetch a valid VAT number", + WSDL: "http://ec.europa.eu/taxation_customs/vies/checkVatService.wsdl", + Params: &checkVatApprox{CountryCode: "fr", VatNumber: "45327920054"}, + Response: &checkVatApproxResponse{ + CountryCode: "FR", VatNumber: "45327920054", Valid: true, TraderName: "SAS EUROMEDIA", + }, + }, + { + Desc: "Fetch with empty params", + WSDL: "http://ec.europa.eu/taxation_customs/vies/checkVatService.wsdl", + Params: &checkVatApprox{}, + Err: `[soap:Server]: Invalid_input | Detail: `, + }, +} + +func TestClient_MarshalWithEncoder(t *testing.T) { + for _, test := range encoderParamsTests { + soap, err := SoapClient(test.WSDL, nil) + if err != nil { + t.Errorf("%s: error not expected creating client: %s", test.Desc, err) + continue + } + + resp, err := soap.CallByStruct(test.Params) + if err != nil { + t.Errorf("%s: error not expected calling API: %s", test.Desc, err) + continue + } + + var actualResponse checkVatApproxResponse + err = resp.Unmarshal(&actualResponse) + if test.Err != "" { + if err == nil { + t.Errorf("%s: expected error, but got response: %#v", test.Desc, actualResponse) + continue + } else if err.Error() != test.Err { + t.Errorf("%s: error doesn't match expectation: %s", test.Desc, err) + } + } else { + if err != nil { + t.Errorf("%s: unmarshal error not expected: %s", test.Desc, err) + continue + } else if actualResponse != *test.Response { + t.Errorf("%s: response doesn't match expectation: %#v", test.Desc, actualResponse) + continue + } + + } + } +} diff --git a/soap_test.go b/soap_test.go index 90481b2..fdcde35 100644 --- a/soap_test.go +++ b/soap_test.go @@ -185,14 +185,14 @@ func TestClient_Call(t *testing.T) { } c := &Client{} - res, err = c.Call("", Params{}) + _, err = c.Call("", Params{}) if err == nil { t.Errorf("error expected but nothing got.") } c.SetWSDL("://test.") - res, err = c.Call("checkVat", params) + _, err = c.Call("checkVat", params) if err == nil { t.Errorf("invalid WSDL") }