package server

import (
	"context"
	"crypto"
	"encoding/json"
	"fmt"
	"net/http"
	"time"

	"crispbyte.dev/sig-auth/keydirectory"
	"github.com/common-fate/httpsig"
	"github.com/common-fate/httpsig/inmemory"
	"golang.org/x/crypto/ssh"
)

// Start registers the HTTP routes and serves them on localhost:8080,
// blocking until the server stops. Like http.Server.ListenAndServe, it
// always returns a non-nil error.
//
// Routes:
//   - /auth: requests must carry an HTTP message signature (tag "auth")
//     that the httpsig middleware verifies against keys in keyDir; verified
//     requests are passed through rewriteHeaders to the default handler.
//   - /register: registers a new public key in keyDir.
func Start(keyDir keydirectory.RegistrationDirectory) error {
	mux := http.NewServeMux()

	validationOptions := httpsig.DefaultValidationOpts()
	// Signatures on /auth are not required to cover a Content-Digest header.
	delete(validationOptions.RequiredCoveredComponents, "content-digest")

	// NOTE(review): Scheme/Authority presumably stand in for the request's
	// scheme and authority when the middleware derives the signing string.
	// The authority (localhost:8001) differs from the listen address below,
	// presumably because clients reach this server through a proxy — confirm.
	verifier := httpsig.Middleware(httpsig.MiddlewareOpts{
		NonceStorage: inmemory.NewNonceStorage(),
		KeyDirectory: keyDir,
		Tag:          "auth",
		Scheme:       "http",
		Authority:    "localhost:8001",
		Validation:   &validationOptions,

		OnValidationError: func(ctx context.Context, err error) {
			fmt.Printf("validation error: %s\n", err)
		},

		OnDeriveSigningString: func(ctx context.Context, stringToSign string) {
			fmt.Printf("string to sign:\n%s\n", stringToSign)
		},
	})

	mux.Handle("/auth", rewriteHeaders(verifier(getDefaultHandler())))
	mux.Handle("/register", getRegistrationHandler(keyDir))

	// http.ListenAndServe uses a Server with no timeouts, which lets a slow
	// or idle client hold a connection (and its goroutine) open forever.
	// Bound every phase of the connection instead.
	server := &http.Server{
		Addr:              "localhost:8080",
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       120 * time.Second,
	}

	return server.ListenAndServe()
}

// getDefaultHandler returns the handler that Start mounts (behind the
// signature verifier) on /auth.
//
// It reads the attributes the httpsig middleware stored in the request
// context and expects them to be a string. Presumably this is the user ID
// associated with the verified key; confirm against the key directory. The
// value is echoed in a Remote-User response header and in a greeting in the
// body. If the attributes are missing or not a string, it responds with 500
// instead of panicking.
func getDefaultHandler() http.Handler {
	handler := func(w http.ResponseWriter, r *http.Request) {
		// An unchecked assertion would panic if the key directory ever
		// supplied non-string attributes, or none at all.
		user, ok := httpsig.AttributesFromContext(r.Context()).(string)
		if !ok {
			fmt.Println("request attributes are missing or not a string")
			http.Error(w, "Server error", http.StatusInternalServerError)
			return
		}

		w.Header().Add("Remote-User", user)

		// A failed write means the client went away; there is nothing left
		// to report it to.
		_, _ = fmt.Fprintf(w, "hello, %s!", user)
	}

	return http.HandlerFunc(handler)
}

// getRegistrationHandler returns the handler for the key registration
// endpoint.
//
// It accepts only POST requests whose body is a JSON-encoded RegisterRequest.
// The request's Key is parsed as an OpenSSH authorized_keys line, checked
// with isValidKeyType, and registered in keyDir for request.UserId. On
// success, the key ID returned by keyDir.RegisterKey is written as the
// response body.
//
// Responses:
//   - 405 (with an Allow header) if the method is not POST.
//   - 400 if the body is not valid JSON or exceeds the size limit, the key
//     cannot be parsed, or the key type is rejected.
//   - 500 if the key directory fails to register the key. Details are
//     logged, not returned to the client.
func getRegistrationHandler(keyDir keydirectory.RegistrationDirectory) http.Handler {
	// An OpenSSH public key line is at most a few KiB, so this leaves ample
	// headroom while stopping a client from streaming an unbounded body into
	// the JSON decoder.
	const maxBodyBytes = 64 << 10

	handler := func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			w.Header().Set("Allow", http.MethodPost)
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}

		// The body is untrusted; bound how much of it the decoder may read.
		r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)

		var request RegisterRequest
		if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
			fmt.Println(err)
			http.Error(w, fmt.Sprintf("Bad request - %s", err), http.StatusBadRequest)
			return
		}

		key, err := parsePublicKey(request.Key)
		if err != nil {
			fmt.Println(err)
			http.Error(w, fmt.Sprintf("Bad request - %s", err), http.StatusBadRequest)
			return
		}

		if !isValidKeyType(key) {
			fmt.Println("Attempted to register invalid key type")
			http.Error(w, "Invalid key type", http.StatusBadRequest)
			return
		}

		// %q escapes control characters in the client-supplied ID so it
		// cannot forge extra log lines.
		fmt.Printf("Registering key for %q\n", request.UserId)

		keyID, err := keyDir.RegisterKey(key, request.UserId)
		if err != nil {
			// Log the details server-side only; they may expose internals
			// of the key directory.
			fmt.Println(err)
			http.Error(w, "Server error", http.StatusInternalServerError)
			return
		}

		// A failed write means the client went away; there is nothing left
		// to report it to.
		_, _ = w.Write([]byte(keyID))
	}

	return http.HandlerFunc(handler)
}

// parsePublicKey parses a public key in OpenSSH authorized_keys format
// (e.g. "ssh-ed25519 AAAA... comment") and returns the standard-library key
// it wraps, as exposed by ssh.CryptoPublicKey.
//
// Only the first key in input is parsed; any remaining data is ignored. An
// error is returned if input cannot be parsed, or if the parsed key does not
// implement ssh.CryptoPublicKey. SSH certificates, for instance, do not, and
// input here comes straight from clients, so this must not panic.
func parsePublicKey(input string) (crypto.PublicKey, error) {
	pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(input))
	if err != nil {
		return nil, fmt.Errorf("parsing authorized key: %w", err)
	}

	cpk, ok := pk.(ssh.CryptoPublicKey)
	if !ok {
		return nil, fmt.Errorf("unsupported key type %q", pk.Type())
	}

	return cpk.CryptoPublicKey(), nil
}