cheddar
·
2025-02-22
server.go
1package server
2
import (
	"context"
	"crypto"
	"encoding/json"
	"fmt"
	"net/http"
	"time"

	"crispbyte.dev/sig-auth/keydirectory"
	"github.com/common-fate/httpsig"
	"github.com/common-fate/httpsig/inmemory"
	"golang.org/x/crypto/ssh"
)
15
16func Start(keyDir keydirectory.RegistrationDirectory) error {
17 mux := http.NewServeMux()
18
19 validationOptions := httpsig.DefaultValidationOpts()
20 delete(validationOptions.RequiredCoveredComponents, "content-digest")
21
22 verifier := httpsig.Middleware(httpsig.MiddlewareOpts{
23 NonceStorage: inmemory.NewNonceStorage(),
24 KeyDirectory: keyDir,
25 Tag: "auth",
26 Scheme: "http",
27 Authority: "localhost:8001",
28 Validation: &validationOptions,
29
30 OnValidationError: func(ctx context.Context, err error) {
31 fmt.Printf("validation error: %s\n", err)
32 },
33
34 OnDeriveSigningString: func(ctx context.Context, stringToSign string) {
35 fmt.Printf("string to sign:\n%s\n", stringToSign)
36 },
37 })
38
39 verifyHandler := verifier(getDefaultHandler())
40
41 handler := rewriteHeaders(verifyHandler)
42
43 mux.Handle("/auth", handler)
44 mux.Handle("/register", getRegistrationHandler(keyDir))
45
46 err := http.ListenAndServe("localhost:8080", mux)
47
48 return err
49}
50
51func getDefaultHandler() http.Handler {
52 handler := func(w http.ResponseWriter, r *http.Request) {
53 attr := httpsig.AttributesFromContext(r.Context()).(string)
54
55 w.Header().Add("Remote-User", attr)
56 msg := fmt.Sprintf("hello, %s!", attr)
57 w.Write([]byte(msg))
58 }
59
60 return http.HandlerFunc(handler)
61}
62
63func getRegistrationHandler(keyDir keydirectory.RegistrationDirectory) http.Handler {
64 handler := func(w http.ResponseWriter, r *http.Request) {
65 if r.Method != "POST" {
66 http.Error(w, "Bad request", 400)
67 return
68 }
69
70 var request RegisterRequest
71
72 err := json.NewDecoder(r.Body).Decode(&request)
73
74 if err != nil {
75 fmt.Println(err)
76 http.Error(w, fmt.Sprintf("Bad request - %s", err), 400)
77 return
78 }
79
80 key, err := parsePublicKey(request.Key)
81
82 if err != nil {
83 fmt.Println(err)
84 http.Error(w, fmt.Sprintf("Bad request - %s", err), 400)
85 return
86 }
87
88 if !isValidKeyType(key) {
89 fmt.Println("Attempted to register invalid key type")
90 http.Error(w, "Invalid key type", 400)
91 return
92 }
93
94 fmt.Printf("Registering key for %s\n", request.UserId)
95
96 keyId, err := keyDir.RegisterKey(key, request.UserId)
97
98 if err != nil {
99 fmt.Println(err)
100 http.Error(w, fmt.Sprintf("Server error - %s", err), 500)
101 return
102 }
103
104 w.Write([]byte(keyId))
105 }
106
107 return http.HandlerFunc(handler)
108}
109
110func parsePublicKey(input string) (crypto.PublicKey, error) {
111 pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(input))
112
113 if err != nil {
114 return nil, err
115 }
116
117 return pk.(ssh.CryptoPublicKey).CryptoPublicKey(), err
118}