aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/api/api_test.go45
1 files changed, 28 insertions, 17 deletions
diff --git a/internal/api/api_test.go b/internal/api/api_test.go
index 50c11327..caa620b4 100644
--- a/internal/api/api_test.go
+++ b/internal/api/api_test.go
@@ -5,8 +5,10 @@
package api
import (
+ "bytes"
"context"
"encoding/json"
+ "errors"
"net/http"
"net/http/httptest"
"strconv"
@@ -206,23 +208,12 @@ func TestServePackage(t *testing.T) {
}
if test.want != nil {
- switch want := test.want.(type) {
- case *Package:
- var got Package
- if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
- t.Fatalf("json.Unmarshal Package: %v", err)
- }
- if diff := cmp.Diff(want, &got); diff != "" {
- t.Errorf("mismatch (-want +got):\n%s", diff)
- }
- case *Error:
- var got Error
- if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil {
- t.Fatalf("json.Unmarshal Error: %v. Body: %s", err, w.Body.String())
- }
- if diff := cmp.Diff(want, &got); diff != "" {
- t.Errorf("mismatch (-want +got):\n%s", diff)
- }
+ got, err := unmarshalResponse[Package](w.Body.Bytes())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if diff := cmp.Diff(test.want, got); diff != "" {
+ t.Errorf("mismatch (-want +got):\n%s", diff)
}
}
})
@@ -710,3 +701,23 @@ func TestServePackageSymbols(t *testing.T) {
})
}
}
+
+// unmarshalResponse unmarshals an API response into either
+// a *T or an *Error.
+func unmarshalResponse[T any](data []byte) (any, error) {
+ d := json.NewDecoder(bytes.NewReader(data))
+ d.DisallowUnknownFields()
+ var t T
+ err1 := d.Decode(&t)
+ if err1 == nil {
+ return &t, nil
+ }
+ d = json.NewDecoder(bytes.NewReader(data))
+ d.DisallowUnknownFields()
+ var e Error
+ err2 := d.Decode(&e)
+ if err2 == nil {
+ return &e, nil
+ }
+ return nil, errors.Join(err1, err2)
+}