Add support for more key types

This commit is contained in:
cheddar 2025-02-20 22:53:18 -05:00
parent b1e4a0cf72
commit 001a4b4ac5
No known key found for this signature in database
14 changed files with 98 additions and 48 deletions

View file

@ -28,7 +28,7 @@ func (dir inMemoryDirectory) GetKey(ctx context.Context, keyId string, _ string)
return entry.toAlg()
}
func (dir inMemoryDirectory) RegisterKey(key crypto.PublicKey, alg string, userId string) (string, error) {
func (dir inMemoryDirectory) RegisterKey(key crypto.PublicKey, userId string) (string, error) {
keyId, err := generateKeyId()
if err != nil {
@ -36,7 +36,6 @@ func (dir inMemoryDirectory) RegisterKey(key crypto.PublicKey, alg string, userI
}
dir.records[keyId] = keyEntry{
Alg: alg,
PublicKey: key,
UserId: userId,
}

View file

@ -2,15 +2,19 @@ package keydirectory
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"fmt"
"reflect"
"github.com/common-fate/httpsig/alg_ecdsa"
"github.com/common-fate/httpsig/alg_ed25519"
"github.com/common-fate/httpsig/alg_rsa"
"github.com/common-fate/httpsig/verifier"
)
type keyEntry struct {
Alg string
PublicKey crypto.PublicKey
UserId string
}
@ -19,14 +23,24 @@ func (k keyEntry) toAlg() (verifier.Algorithm, error) {
var alg verifier.Algorithm
var err error
switch k.Alg {
case "ed25519":
switch k.PublicKey.(type) {
case ed25519.PublicKey:
alg = alg_ed25519.Ed25519{
PublicKey: k.PublicKey.(ed25519.PublicKey),
Attrs: k.UserId,
}
case *rsa.PublicKey:
alg = alg_rsa.RSAPKCS256{
PublicKey: k.PublicKey.(*rsa.PublicKey),
Attrs: k.UserId,
}
case *ecdsa.PublicKey:
alg = alg_ecdsa.P256{
PublicKey: k.PublicKey.(*ecdsa.PublicKey),
Attrs: k.UserId,
}
default:
err = fmt.Errorf("unknown algoritm: %s", k.Alg)
err = fmt.Errorf("unknown key type: %s", reflect.TypeOf(k.PublicKey))
}
return alg, err

View file

@ -8,5 +8,5 @@ import (
type RegistrationDirectory interface {
verifier.KeyDirectory
RegisterKey(key crypto.PublicKey, alg string, userId string) (string, error)
RegisterKey(key crypto.PublicKey, userId string) (string, error)
}

View file

@ -3,9 +3,8 @@ package keydirectory
import (
"context"
"crypto"
"crypto/ed25519"
"crypto/x509"
"database/sql"
"fmt"
"github.com/common-fate/httpsig/verifier"
_ "github.com/mattn/go-sqlite3"
@ -40,7 +39,7 @@ func InitSqlite(dbPath string) (*dbWrapper, error) {
func (dir *dbWrapper) GetKey(ctx context.Context, keyId string, _ string) (verifier.Algorithm, error) {
db := dir.db
query := "select userId, alg, publicKey from keys where keyId = ?"
query := "select userId, publicKey from keys where keyId = ?"
stmt, err := db.Prepare(query)
@ -51,28 +50,23 @@ func (dir *dbWrapper) GetKey(ctx context.Context, keyId string, _ string) (verif
defer stmt.Close()
var userId string
var alg string
var keyBytes []byte
row := stmt.QueryRow(keyId)
err = row.Scan(&userId, &alg, &keyBytes)
err = row.Scan(&userId, &keyBytes)
if err != nil {
return nil, err
}
var publicKey crypto.PublicKey
publicKey, err := x509.ParsePKIXPublicKey(keyBytes)
switch alg {
case "ed25519":
publicKey = ed25519.PublicKey(keyBytes)
default:
return nil, fmt.Errorf("unknown algorithm: %s", alg)
if err != nil {
return nil, err
}
keyEntry := keyEntry{
Alg: alg,
UserId: userId,
PublicKey: publicKey,
}
@ -80,7 +74,7 @@ func (dir *dbWrapper) GetKey(ctx context.Context, keyId string, _ string) (verif
return keyEntry.toAlg()
}
func (dir *dbWrapper) RegisterKey(key crypto.PublicKey, alg string, userId string) (string, error) {
func (dir *dbWrapper) RegisterKey(key crypto.PublicKey, userId string) (string, error) {
db := dir.db
keyId, err := generateKeyId()
@ -89,18 +83,15 @@ func (dir *dbWrapper) RegisterKey(key crypto.PublicKey, alg string, userId strin
return "", err
}
stmt := "insert into keys(keyId, userId, alg, publicKey) values (?, ?, ?, ?)"
stmt := "insert into keys(keyId, userId, publicKey) values (?, ?, ?)"
var keyBytes []byte
keyBytes, err := x509.MarshalPKIXPublicKey(key)
switch alg {
case "ed25519":
keyBytes = []byte(key.(ed25519.PublicKey))
default:
return "", fmt.Errorf("unknown algorithm: %s", alg)
if err != nil {
return "", err
}
_, err = db.Exec(stmt, keyId, userId, alg, keyBytes)
_, err = db.Exec(stmt, keyId, userId, keyBytes)
if err != nil {
return "", err