sig-auth.git

git clone https://git.crispbyte.dev/sig-auth.git

sig-auth.git / server
cheddar  ·  2025-02-22

server.go

  1package server
  2
  3import (
  4	"context"
  5	"crypto"
  6	"encoding/json"
  7	"fmt"
  8	"net/http"
  9
 10	"crispbyte.dev/sig-auth/keydirectory"
 11	"github.com/common-fate/httpsig"
 12	"github.com/common-fate/httpsig/inmemory"
 13	"golang.org/x/crypto/ssh"
 14)
 15
 16func Start(keyDir keydirectory.RegistrationDirectory) error {
 17	mux := http.NewServeMux()
 18
 19	validationOptions := httpsig.DefaultValidationOpts()
 20	delete(validationOptions.RequiredCoveredComponents, "content-digest")
 21
 22	verifier := httpsig.Middleware(httpsig.MiddlewareOpts{
 23		NonceStorage: inmemory.NewNonceStorage(),
 24		KeyDirectory: keyDir,
 25		Tag:          "auth",
 26		Scheme:       "http",
 27		Authority:    "localhost:8001",
 28		Validation:   &validationOptions,
 29
 30		OnValidationError: func(ctx context.Context, err error) {
 31			fmt.Printf("validation error: %s\n", err)
 32		},
 33
 34		OnDeriveSigningString: func(ctx context.Context, stringToSign string) {
 35			fmt.Printf("string to sign:\n%s\n", stringToSign)
 36		},
 37	})
 38
 39	verifyHandler := verifier(getDefaultHandler())
 40
 41	handler := rewriteHeaders(verifyHandler)
 42
 43	mux.Handle("/auth", handler)
 44	mux.Handle("/register", getRegistrationHandler(keyDir))
 45
 46	err := http.ListenAndServe("localhost:8080", mux)
 47
 48	return err
 49}
 50
 51func getDefaultHandler() http.Handler {
 52	handler := func(w http.ResponseWriter, r *http.Request) {
 53		attr := httpsig.AttributesFromContext(r.Context()).(string)
 54
 55		w.Header().Add("Remote-User", attr)
 56		msg := fmt.Sprintf("hello, %s!", attr)
 57		w.Write([]byte(msg))
 58	}
 59
 60	return http.HandlerFunc(handler)
 61}
 62
 63func getRegistrationHandler(keyDir keydirectory.RegistrationDirectory) http.Handler {
 64	handler := func(w http.ResponseWriter, r *http.Request) {
 65		if r.Method != "POST" {
 66			http.Error(w, "Bad request", 400)
 67			return
 68		}
 69
 70		var request RegisterRequest
 71
 72		err := json.NewDecoder(r.Body).Decode(&request)
 73
 74		if err != nil {
 75			fmt.Println(err)
 76			http.Error(w, fmt.Sprintf("Bad request - %s", err), 400)
 77			return
 78		}
 79
 80		key, err := parsePublicKey(request.Key)
 81
 82		if err != nil {
 83			fmt.Println(err)
 84			http.Error(w, fmt.Sprintf("Bad request - %s", err), 400)
 85			return
 86		}
 87
 88		if !isValidKeyType(key) {
 89			fmt.Println("Attempted to register invalid key type")
 90			http.Error(w, "Invalid key type", 400)
 91			return
 92		}
 93
 94		fmt.Printf("Registering key for %s\n", request.UserId)
 95
 96		keyId, err := keyDir.RegisterKey(key, request.UserId)
 97
 98		if err != nil {
 99			fmt.Println(err)
100			http.Error(w, fmt.Sprintf("Server error - %s", err), 500)
101			return
102		}
103
104		w.Write([]byte(keyId))
105	}
106
107	return http.HandlerFunc(handler)
108}
109
110func parsePublicKey(input string) (crypto.PublicKey, error) {
111	pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(input))
112
113	if err != nil {
114		return nil, err
115	}
116
117	return pk.(ssh.CryptoPublicKey).CryptoPublicKey(), err
118}