WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / IT CANNOT LOGIN OR REGISTER ACCOUNTS / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FOUND SOMETHING MAY NOT GOOD FOR EVERYONE, CONTACT ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit 7ccfb4a

Browse files
committed
Simplify types in message package with generics
1 parent 7df0fd4 commit 7ccfb4a

File tree

3 files changed

+100
-119
lines changed

3 files changed

+100
-119
lines changed

client.go

Lines changed: 26 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
166166
}
167167

168168
// Decode the response
169-
r := message.EnrollResponse{}
169+
r := message.APIResponse[message.EnrollResponseData]{}
170170
b, err := io.ReadAll(resp.Body)
171171
if err != nil {
172172
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
@@ -637,32 +637,23 @@ func (c *Client) EndpointPreAuth(ctx context.Context) (*message.PreAuthData, err
637637
defer resp.Body.Close()
638638

639639
reqID := resp.Header.Get("X-Request-ID")
640-
respBody, err := io.ReadAll(resp.Body)
640+
641+
r := message.APIResponse[message.PreAuthData]{}
642+
b, err := io.ReadAll(resp.Body)
641643
if err != nil {
642-
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
644+
return nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
643645
}
644646

645-
switch resp.StatusCode {
646-
case http.StatusOK:
647-
r := message.PreAuthResponse{}
648-
if err = json.Unmarshal(respBody, &r); err != nil {
649-
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
650-
}
651-
652-
if r.Data.PollToken == "" || r.Data.LoginURL == "" {
653-
return nil, &APIError{e: fmt.Errorf("missing pollToken or loginURL"), ReqID: reqID}
654-
}
647+
if err := json.Unmarshal(b, &r); err != nil {
648+
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
649+
}
655650

656-
return &r.Data, nil
657-
default:
658-
var errors struct {
659-
Errors message.APIErrors
660-
}
661-
if err := json.Unmarshal(respBody, &errors); err != nil {
662-
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
663-
}
664-
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
651+
// Check for any errors returned by the API
652+
if err := r.Errors.ToError(); err != nil {
653+
return nil, &APIError{e: err, ReqID: reqID}
665654
}
655+
656+
return &r.Data, nil
666657
}
667658

668659
func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*message.EndpointAuthPollData, error) {
@@ -684,25 +675,21 @@ func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*messag
684675
defer resp.Body.Close()
685676

686677
reqID := resp.Header.Get("X-Request-ID")
687-
respBody, err := io.ReadAll(resp.Body)
678+
679+
r := message.APIResponse[message.EndpointAuthPollData]{}
680+
b, err := io.ReadAll(resp.Body)
688681
if err != nil {
689-
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
682+
return nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
690683
}
691684

692-
switch resp.StatusCode {
693-
case http.StatusOK:
694-
r := message.EndpointAuthPollResponse{}
695-
if err = json.Unmarshal(respBody, &r); err != nil {
696-
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
697-
}
698-
return &r.Data, nil
699-
default:
700-
var errors struct {
701-
Errors message.APIErrors
702-
}
703-
if err := json.Unmarshal(respBody, &errors); err != nil {
704-
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
705-
}
706-
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
685+
if err := json.Unmarshal(b, &r); err != nil {
686+
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
707687
}
688+
689+
// Check for any errors returned by the API
690+
if err := r.Errors.ToError(); err != nil {
691+
return nil, &APIError{e: err, ReqID: reqID}
692+
}
693+
694+
return &r.Data, nil
708695
}

client_test.go

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,15 @@ func TestEnroll(t *testing.T) {
6363
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
6464
})
6565
if err != nil {
66-
return jsonMarshal(message.EnrollResponse{
66+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
6767
Errors: message.APIErrors{{
6868
Code: "ERR_FAILED_TO_MARSHAL_YAML",
6969
Message: "failed to marshal test response config",
7070
}},
7171
})
7272
}
7373

74-
return jsonMarshal(message.EnrollResponse{
74+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
7575
Data: message.EnrollResponseData{
7676
HostID: hostID,
7777
Counter: counter,
@@ -143,7 +143,7 @@ func TestEnroll(t *testing.T) {
143143
// Test error handling
144144
errorMsg := "invalid enrollment code"
145145
ts.ExpectEnrollment(code, message.NetworkCurve25519, func(req message.EnrollRequest) []byte {
146-
return jsonMarshal(message.EnrollResponse{
146+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
147147
Errors: message.APIErrors{{
148148
Code: "ERR_INVALID_ENROLLMENT_CODE",
149149
Message: errorMsg,
@@ -188,15 +188,15 @@ func TestDoUpdate(t *testing.T) {
188188
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
189189
})
190190
if err != nil {
191-
return jsonMarshal(message.EnrollResponse{
191+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
192192
Errors: message.APIErrors{{
193193
Code: "ERR_FAILED_TO_MARSHAL_YAML",
194194
Message: "failed to marshal test response config",
195195
}},
196196
})
197197
}
198198

199-
return jsonMarshal(message.EnrollResponse{
199+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
200200
Data: message.EnrollResponseData{
201201
HostID: "foobar",
202202
Counter: 1,
@@ -452,15 +452,15 @@ func TestDoUpdate_P256(t *testing.T) {
452452
"test": m{"code": req.Code, "p256Pubkey": req.NebulaPubkeyP256},
453453
})
454454
if err != nil {
455-
return jsonMarshal(message.EnrollResponse{
455+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
456456
Errors: message.APIErrors{{
457457
Code: "ERR_FAILED_TO_MARSHAL_YAML",
458458
Message: "failed to marshal test response config",
459459
}},
460460
})
461461
}
462462

463-
return jsonMarshal(message.EnrollResponse{
463+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
464464
Data: message.EnrollResponseData{
465465
HostID: "foobar",
466466
Counter: 1,
@@ -546,7 +546,7 @@ func TestDoUpdate_P256(t *testing.T) {
546546

547547
sig, err := nk.HostP256PrivateKey.Sign(rawRes)
548548
if err != nil {
549-
return jsonMarshal(message.EnrollResponse{
549+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
550550
Errors: message.APIErrors{{
551551
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
552552
Message: "failed to sign message",
@@ -590,7 +590,7 @@ func TestDoUpdate_P256(t *testing.T) {
590590
hashed := sha256.Sum256(rawRes)
591591
sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:])
592592
if err != nil {
593-
return jsonMarshal(message.EnrollResponse{
593+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
594594
Errors: message.APIErrors{{
595595
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
596596
Message: "failed to sign message",
@@ -644,7 +644,7 @@ func TestDoUpdate_P256(t *testing.T) {
644644
hashed := sha256.Sum256(rawRes)
645645
sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:])
646646
if err != nil {
647-
return jsonMarshal(message.EnrollResponse{
647+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
648648
Errors: message.APIErrors{{
649649
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
650650
Message: "failed to sign message",
@@ -692,15 +692,15 @@ func TestCommandResponse(t *testing.T) {
692692
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
693693
})
694694
if err != nil {
695-
return jsonMarshal(message.EnrollResponse{
695+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
696696
Errors: message.APIErrors{{
697697
Code: "ERR_FAILED_TO_MARSHAL_YAML",
698698
Message: "failed to marshal test response config",
699699
}},
700700
})
701701
}
702702

703-
return jsonMarshal(message.EnrollResponse{
703+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
704704
Data: message.EnrollResponseData{
705705
HostID: "foobar",
706706
Counter: 1,
@@ -760,7 +760,7 @@ func TestCommandResponse(t *testing.T) {
760760
// Test error handling
761761
errorMsg := "sample error"
762762
ts.ExpectDNClientRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
763-
return jsonMarshal(message.EnrollResponse{
763+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
764764
Errors: message.APIErrors{{
765765
Code: "ERR_INVALID_VALUE",
766766
Message: errorMsg,
@@ -794,15 +794,15 @@ func TestStreamCommandResponse(t *testing.T) {
794794
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
795795
})
796796
if err != nil {
797-
return jsonMarshal(message.EnrollResponse{
797+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
798798
Errors: message.APIErrors{{
799799
Code: "ERR_FAILED_TO_MARSHAL_YAML",
800800
Message: "failed to marshal test response config",
801801
}},
802802
})
803803
}
804804

805-
return jsonMarshal(message.EnrollResponse{
805+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
806806
Data: message.EnrollResponseData{
807807
HostID: "foobar",
808808
Counter: 1,
@@ -871,7 +871,7 @@ func TestStreamCommandResponse(t *testing.T) {
871871
// Test error handling
872872
errorMsg := "sample error"
873873
ts.ExpectStreamingRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
874-
return jsonMarshal(message.EnrollResponse{
874+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
875875
Errors: message.APIErrors{{
876876
Code: "ERR_INVALID_VALUE",
877877
Message: errorMsg,
@@ -920,15 +920,15 @@ func TestReauthenticate(t *testing.T) {
920920
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
921921
})
922922
if err != nil {
923-
return jsonMarshal(message.EnrollResponse{
923+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
924924
Errors: message.APIErrors{{
925925
Code: "ERR_FAILED_TO_MARSHAL_YAML",
926926
Message: "failed to marshal test response config",
927927
}},
928928
})
929929
}
930930

931-
return jsonMarshal(message.EnrollResponse{
931+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
932932
Data: message.EnrollResponseData{
933933
HostID: "foobar",
934934
Counter: 1,
@@ -1065,7 +1065,7 @@ func TestGetOidcPollCode(t *testing.T) {
10651065
t.Cleanup(func() { ts.Close() })
10661066
const expectedCode = "123456"
10671067
ts.ExpectAPIRequest(http.StatusOK, func(req any) []byte {
1068-
return jsonMarshal(message.PreAuthResponse{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}})
1068+
return jsonMarshal(message.APIResponse[message.PreAuthData]{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}})
10691069
})
10701070

10711071
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
@@ -1079,8 +1079,13 @@ func TestGetOidcPollCode(t *testing.T) {
10791079
assert.Equal(t, 0, ts.RequestsRemaining())
10801080

10811081
//unhappy path
1082-
ts.ExpectAPIRequest(http.StatusBadGateway, func(req any) []byte {
1083-
return jsonMarshal(message.PreAuthResponse{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}})
1082+
ts.ExpectAPIRequest(http.StatusInternalServerError, func(req any) []byte {
1083+
return jsonMarshal(message.APIResponse[message.PreAuthData]{
1084+
Errors: message.APIErrors{{
1085+
Code: "ERR_INTERNAL_SERVER_ERROR",
1086+
Message: "internal server error",
1087+
}},
1088+
})
10841089
})
10851090
resp, err = client.EndpointPreAuth(ctx)
10861091
require.Error(t, err)
@@ -1099,7 +1104,7 @@ func TestDoOidcPoll(t *testing.T) {
10991104
t.Cleanup(func() { ts.Close() })
11001105
const expectedCode = "123456"
11011106
ts.ExpectAPIRequest(http.StatusOK, func(r any) []byte {
1102-
return jsonMarshal(message.EndpointAuthPollResponse{Data: message.EndpointAuthPollData{
1107+
return jsonMarshal(message.APIResponse[message.EndpointAuthPollData]{Data: message.EndpointAuthPollData{
11031108
Status: message.EndpointAuthStarted,
11041109
EnrollmentCode: "",
11051110
}})
@@ -1126,7 +1131,7 @@ func TestDoOidcPoll(t *testing.T) {
11261131

11271132
//complete path
11281133
ts.ExpectAPIRequest(http.StatusOK, func(r any) []byte {
1129-
return jsonMarshal(message.EndpointAuthPollResponse{Data: message.EndpointAuthPollData{
1134+
return jsonMarshal(message.APIResponse[message.EndpointAuthPollData]{Data: message.EndpointAuthPollData{
11301135
Status: message.EndpointAuthCompleted,
11311136
EnrollmentCode: "deadbeef",
11321137
}})

0 commit comments

Comments
 (0)