diff --git a/toolkit/server/exectx/middlewares.go b/toolkit/server/exectx/middlewares.go index c2d96c4..c431905 100644 --- a/toolkit/server/exectx/middlewares.go +++ b/toolkit/server/exectx/middlewares.go @@ -18,7 +18,7 @@ func ExtractMW(next http.Handler, serviceDID did.DID) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctn, err := bearer.ExtractBearerContainer(r.Header) if errors.Is(err, bearer.ErrNoUcan) { - http.Error(w, "no UCAN auth", http.StatusBadRequest) + http.Error(w, "no UCAN auth", http.StatusUnauthorized) return } if errors.Is(err, bearer.ErrContainerMalformed) { @@ -34,7 +34,7 @@ func ExtractMW(next http.Handler, serviceDID did.DID) http.Handler { // prepare a UcanCtx from the container, for further evaluation in the server pipeline ucanCtx, err := FromContainer(ctn) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), http.StatusUnauthorized) return } diff --git a/toolkit/server/exectx/middlewares_test.go b/toolkit/server/exectx/middlewares_test.go new file mode 100644 index 0000000..00dd652 --- /dev/null +++ b/toolkit/server/exectx/middlewares_test.go @@ -0,0 +1,128 @@ +package exectx + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/ipfs/go-cid" + "github.com/stretchr/testify/require" + "github.com/ucan-wg/go-ucan/did/didtest" + "github.com/ucan-wg/go-ucan/pkg/container" + "github.com/ucan-wg/go-ucan/token/delegation" + "github.com/ucan-wg/go-ucan/token/invocation" +) + +func TestExtractMW(t *testing.T) { + const service = didtest.PersonaAlice + const client = didtest.PersonaBob + const cmd = "/foo/bar" + + for _, tc := range []struct { + name string + addHeaderFn func(func(key string, value string)) + expectedStatusCode int + successful bool + }{ + { + name: "no auth", + addHeaderFn: func(f func(key string, value string)) {}, + expectedStatusCode: http.StatusUnauthorized, + successful: false, + }, + { + name: "wrong kind of auth", + addHeaderFn: func(f func(key string, value string)) { + f("Authorization", "Basic foobar") + }, + expectedStatusCode: http.StatusUnauthorized, + successful: false, + }, + { + name: "invalid container", + addHeaderFn: func(f func(key string, value string)) { + f("Authorization", "Bearer foobar") + }, + expectedStatusCode: http.StatusBadRequest, + successful: false, + }, + { + name: "valid containter, for incorrect service", + addHeaderFn: func(f func(key string, value string)) { + cont := container.NewWriter() + + const service2 = didtest.PersonaCarol + + dlg, _ := delegation.Root(service2.DID(), client.DID(), cmd, nil) + dlgByte, dlgCid, _ := dlg.ToSealed(service2.PrivKey()) + cont.AddSealed(dlgByte) + + inv, _ := invocation.New(client.DID(), cmd, service2.DID(), []cid.Cid{dlgCid}) + invBytes, _, _ := inv.ToSealed(client.PrivKey()) + cont.AddSealed(invBytes) + + contB64, _ := cont.ToBase64StdPadding() + + f("Authorization", "Bearer "+contB64) + }, + expectedStatusCode: http.StatusUnauthorized, + successful: false, + }, + { + name: "valid containter, missing invocation", + addHeaderFn: func(f func(key string, value string)) { + cont := container.NewWriter() + + dlg, _ := delegation.Root(service.DID(), client.DID(), cmd, nil) + dlgByte, _, _ := dlg.ToSealed(service.PrivKey()) + cont.AddSealed(dlgByte) + + contB64, _ := cont.ToBase64StdPadding() + + f("Authorization", "Bearer "+contB64) + }, + expectedStatusCode: http.StatusUnauthorized, + successful: false, + }, + { + name: "valid containter, valid tokens", + addHeaderFn: func(f func(key string, value string)) { + cont := container.NewWriter() + + dlg, _ := delegation.Root(service.DID(), client.DID(), cmd, nil) + dlgByte, dlgCid, _ := dlg.ToSealed(service.PrivKey()) + cont.AddSealed(dlgByte) + + inv, _ := invocation.New(client.DID(), cmd, service.DID(), []cid.Cid{dlgCid}) + invBytes, _, _ := inv.ToSealed(client.PrivKey()) + cont.AddSealed(invBytes) + + contB64, _ := cont.ToBase64StdPadding() + + f("Authorization", "Bearer "+contB64) + }, + expectedStatusCode: http.StatusOK, + successful: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, has := FromContext(r.Context()) + require.Equal(t, tc.successful, has) + + _, _ = io.WriteString(w, "OK") + }) + handler = ExtractMW(handler, service.DID()) + + req := httptest.NewRequest("GET", "https://example.com/foo", nil) + tc.addHeaderFn(req.Header.Set) + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + require.Equal(t, tc.expectedStatusCode, w.Code) + + }) + } +}