From 842fdc1923f8537b6ed0e5b70861afbb0c843032 Mon Sep 17 00:00:00 2001 From: Prad Nukala Date: Thu, 8 Jan 2026 20:16:50 -0500 Subject: [PATCH] refactor(keybase): switch to native SQLite serialization and deserialization --- TODO.md | 49 ++++++------ go.mod | 1 + go.sum | 2 + internal/keybase/conn.go | 156 +++++++++++-------------------------- internal/types/generate.go | 31 +++++++- main.go | 107 ++++++++++++++++++++++--- 6 files changed, 202 insertions(+), 144 deletions(-) diff --git a/TODO.md b/TODO.md index 1614d12..fc94c26 100644 --- a/TODO.md +++ b/TODO.md @@ -232,13 +232,13 @@ The following files implement UCAN v1.0.0-rc.1 using the official go-ucan librar ### 4.1 Key Share Actions -- [ ] `CreateKeyShare(ctx, params) (*KeyShareResult, error)` -- [ ] `ListKeyShares(ctx) ([]KeyShareResult, error)` -- [ ] `GetKeyShareByID(ctx, shareID) (*KeyShareResult, error)` -- [ ] `GetKeyShareByKeyID(ctx, keyID) (*KeyShareResult, error)` -- [ ] `RotateKeyShare(ctx, shareID) error` -- [ ] `ArchiveKeyShare(ctx, shareID) error` -- [ ] `DeleteKeyShare(ctx, shareID) error` +- [x] `CreateKeyShare(ctx, params) (*KeyShareResult, error)` +- [x] `ListKeyShares(ctx) ([]KeyShareResult, error)` +- [x] `GetKeyShareByID(ctx, shareID) (*KeyShareResult, error)` +- [x] `GetKeyShareByKeyID(ctx, keyID) (*KeyShareResult, error)` +- [x] `RotateKeyShare(ctx, shareID) error` +- [x] `ArchiveKeyShare(ctx, shareID) error` +- [x] `DeleteKeyShare(ctx, shareID) error` ### 4.2 UCAN Token Actions (v1.0.0-rc.1) @@ -290,12 +290,12 @@ The following files implement UCAN v1.0.0-rc.1 using the official go-ucan librar ### 4.6 Account Actions (Extend Existing) -- [ ] `CreateAccount(ctx, params) (*AccountResult, error)` -- [ ] `ListAccountsByChain(ctx, chainID) ([]AccountResult, error)` -- [ ] `GetDefaultAccount(ctx, chainID) (*AccountResult, error)` -- [ ] `SetDefaultAccount(ctx, accountID, chainID) error` -- [ ] `UpdateAccountLabel(ctx, accountID, label) error` -- [ ] `DeleteAccount(ctx, accountID) error` +- [x] `CreateAccount(ctx, params) (*AccountResult, error)` +- [x] `ListAccountsByChain(ctx, chainID) ([]AccountResult, error)` +- [x] `GetDefaultAccount(ctx, chainID) (*AccountResult, error)` +- [x] `SetDefaultAccount(ctx, accountID, chainID) error` +- [x] `UpdateAccountLabel(ctx, accountID, label) error` +- [x] `DeleteAccount(ctx, accountID) error` ### 4.7 Credential Actions (Extend Existing) @@ -327,23 +327,24 @@ The following files implement UCAN v1.0.0-rc.1 using the official go-ucan librar ### 5.1 Key Share Storage -- [ ] Parse key share data from MPC protocol -- [ ] Encrypt share data before storage -- [ ] Store public key and chain code -- [ ] Track party index and threshold +- [x] Parse key share data from MPC protocol - `KeyShareInput` in generate +- [x] Store public key and chain code - `CreateKeyShare` action +- [x] Track party index and threshold - stored in `key_shares` table +- [ ] Encrypt share data before storage - PRF key derivation needed ### 5.2 Account Derivation +- [x] Basic address derivation from public key - `deriveCosmosAddress()` +- [x] Create initial account during generate - `createInitialAccount()` - [ ] Implement BIP44 derivation path parsing -- [ ] Derive addresses from public keys - [ ] Support multiple chains (Cosmos 118, Ethereum 60) -- [ ] Generate proper address encoding per chain +- [ ] Generate proper bech32 address encoding per chain ### 5.3 Key Rotation -- [ ] Implement key rotation workflow -- [ ] Archive old shares -- [ ] Update status transitions +- [x] Implement key rotation workflow - `RotateKeyShare` action +- [x] Archive old shares - `ArchiveKeyShare` action +- [x] Status transitions - managed in database - [ ] Handle rotation failures gracefully --- @@ -364,11 +365,13 @@ The following files implement UCAN v1.0.0-rc.1 using the official go-ucan librar ### 6.2 Extend `generate` Function +- [x] Accept optional MPC keyshare data in input +- [x] Create initial keyshare if provided +- [x] Create initial account from keyshare - [ ] Parse WebAuthn credential properly (CBOR/COSE format) - [ ] Extract public key from credential - [ ] Create initial verification method - [ ] Create initial credential record -- [ ] Generate initial account (if key share provided) ### 6.3 Signing Function diff --git a/go.mod b/go.mod index 272588a..f849419 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/dustinxie/ecc v0.0.0-20210511000915-959544187564 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/gtank/merlin v0.1.1 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/mimoo/StrobeGo v0.0.0-20181016162300-f8f6d4d2b643 // indirect diff --git a/go.sum b/go.sum index fd744c6..e83332e 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7z github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gtank/merlin v0.1.1 h1:eQ90iG7K9pOhtereWsmyRJ6RAwcP4tHTDBHXNg+u5is= diff --git a/internal/keybase/conn.go b/internal/keybase/conn.go index fdccd93..43ed1c2 100644 --- a/internal/keybase/conn.go +++ b/internal/keybase/conn.go @@ -1,4 +1,3 @@ -// Package keybase contains the SQLite database for cryptographic keys. package keybase import ( @@ -6,18 +5,22 @@ import ( "database/sql" "encoding/json" "fmt" - "strings" "sync" "enclave/internal/migrations" - _ "github.com/ncruces/go-sqlite3/driver" + "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/driver" _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/ext/hash" + "github.com/ncruces/go-sqlite3/ext/serdes" + "github.com/ncruces/go-sqlite3/ext/uuid" ) // Keybase encapsulates the encrypted key storage database. type Keybase struct { db *sql.DB + conn *sqlite3.Conn // raw connection for serdes queries *Queries did string didID int64 @@ -38,19 +41,32 @@ func Open() (*Keybase, error) { return instance, nil } - conn, err := sql.Open("sqlite3", ":memory:") + var rawConn *sqlite3.Conn + initCallback := func(conn *sqlite3.Conn) error { + rawConn = conn + if err := hash.Register(conn); err != nil { + return fmt.Errorf("register hash extension: %w", err) + } + if err := uuid.Register(conn); err != nil { + return fmt.Errorf("register uuid extension: %w", err) + } + return nil + } + + db, err := driver.Open(":memory:", initCallback) if err != nil { return nil, fmt.Errorf("keybase: open database: %w", err) } - if _, err := conn.Exec(migrations.SchemaSQL); err != nil { - conn.Close() + if _, err := db.Exec(migrations.SchemaSQL); err != nil { + db.Close() return nil, fmt.Errorf("keybase: init schema: %w", err) } instance = &Keybase{ - db: conn, - queries: New(conn), + db: db, + conn: rawConn, + queries: New(db), } return instance, nil @@ -169,10 +185,21 @@ func (k *Keybase) Initialize(ctx context.Context, credentialBytes []byte) (strin // Load restores the database state from serialized bytes and sets the current DID. func (k *Keybase) Load(ctx context.Context, data []byte) (string, error) { - if len(data) < 10 { + if len(data) < 100 { return "", fmt.Errorf("keybase: invalid database format") } + k.mu.Lock() + defer k.mu.Unlock() + + if k.conn == nil { + return "", fmt.Errorf("keybase: database not initialized") + } + + if err := serdes.Deserialize(k.conn, "main", data); err != nil { + return "", fmt.Errorf("keybase: deserialize database: %w", err) + } + docs, err := k.queries.ListAllDIDs(ctx) if err != nil { return "", fmt.Errorf("keybase: list DIDs: %w", err) @@ -182,127 +209,34 @@ func (k *Keybase) Load(ctx context.Context, data []byte) (string, error) { return "", fmt.Errorf("keybase: no DID found in database") } - k.mu.Lock() k.did = docs[0].Did k.didID = docs[0].ID - k.mu.Unlock() return k.did, nil } -// Serialize exports the database state as bytes. +// Serialize exports the database state as bytes using native SQLite serialization. func (k *Keybase) Serialize() ([]byte, error) { k.mu.RLock() defer k.mu.RUnlock() - if k.db == nil { + if k.conn == nil { return nil, fmt.Errorf("keybase: database not initialized") } - return k.exportDump() -} - -func (k *Keybase) exportDump() ([]byte, error) { - var dump strings.Builder - dump.WriteString(migrations.SchemaSQL + "\n") - - tables := []string{ - "did_documents", "verification_methods", "credentials", - "key_shares", "accounts", "ucan_tokens", "ucan_revocations", - "sessions", "services", "grants", "delegations", "sync_checkpoints", - } - - for _, table := range tables { - if err := k.exportTable(&dump, table); err != nil { - continue - } - } - - return []byte(dump.String()), nil -} - -func (k *Keybase) exportTable(dump *strings.Builder, table string) error { - rows, err := k.db.Query(fmt.Sprintf("SELECT * FROM %s", table)) - if err != nil { - return err - } - defer rows.Close() - - cols, err := rows.Columns() - if err != nil { - return err - } - - for rows.Next() { - values := make([]any, len(cols)) - valuePtrs := make([]any, len(cols)) - for i := range values { - valuePtrs[i] = &values[i] - } - - if err := rows.Scan(valuePtrs...); err != nil { - continue - } - - dump.WriteString(fmt.Sprintf("INSERT INTO %s (", table)) - dump.WriteString(strings.Join(cols, ", ")) - dump.WriteString(") VALUES (") - - for i, val := range values { - if i > 0 { - dump.WriteString(", ") - } - dump.WriteString(formatSQLValue(val)) - } - dump.WriteString(");\n") - } - - return rows.Err() -} - -func formatSQLValue(val any) string { - if val == nil { - return "NULL" - } - - switch v := val.(type) { - case int64: - return fmt.Sprintf("%d", v) - case float64: - return fmt.Sprintf("%f", v) - case bool: - if v { - return "1" - } - return "0" - case []byte: - return fmt.Sprintf("'%s'", escapeSQLString(string(v))) - case string: - return fmt.Sprintf("'%s'", escapeSQLString(v)) - default: - return fmt.Sprintf("'%s'", escapeSQLString(fmt.Sprintf("%v", v))) - } -} - -func escapeSQLString(s string) string { - return strings.ReplaceAll(s, "'", "''") + return serdes.Serialize(k.conn, "main") } func (k *Keybase) RestoreFromDump(data []byte) error { k.mu.Lock() defer k.mu.Unlock() - statements := strings.Split(string(data), ";\n") - for _, stmt := range statements { - stmt = strings.TrimSpace(stmt) - if stmt == "" || strings.HasPrefix(stmt, "--") { - continue - } - if strings.HasPrefix(stmt, "INSERT INTO") { - if _, err := k.db.Exec(stmt); err != nil { - return fmt.Errorf("keybase: failed to execute statement: %w", err) - } - } + if k.conn == nil { + return fmt.Errorf("keybase: database not initialized") + } + + if err := serdes.Deserialize(k.conn, "main", data); err != nil { + return fmt.Errorf("keybase: deserialize database: %w", err) } docs, err := k.queries.ListAllDIDs(context.Background()) diff --git a/internal/types/generate.go b/internal/types/generate.go index 9cd550c..2803f36 100644 --- a/internal/types/generate.go +++ b/internal/types/generate.go @@ -2,11 +2,40 @@ package types // GenerateInput represents the input for the generate function type GenerateInput struct { - Credential string `json:"credential"` // Base64-encoded PublicKeyCredential + Credential string `json:"credential"` // Base64-encoded WebAuthn credential + + // MPC keyshare data (optional - if provided, creates initial keyshare and account) + KeyShare *KeyShareInput `json:"key_share,omitempty"` +} + +// KeyShareInput represents MPC keyshare data for initialization +type KeyShareInput struct { + KeyID string `json:"key_id"` + PartyIndex int64 `json:"party_index"` + Threshold int64 `json:"threshold"` + TotalParties int64 `json:"total_parties"` + Curve string `json:"curve"` + ShareData string `json:"share_data"` + PublicKey string `json:"public_key"` + ChainCode string `json:"chain_code,omitempty"` + DerivationPath string `json:"derivation_path,omitempty"` } // GenerateOutput represents the output of the generate function type GenerateOutput struct { DID string `json:"did"` Database []byte `json:"database"` + + // KeyShare info if a keyshare was provided + KeyShareID string `json:"key_share_id,omitempty"` + + // Account info if an account was created + Account *AccountInfo `json:"account,omitempty"` +} + +// AccountInfo represents created account information +type AccountInfo struct { + Address string `json:"address"` + ChainID string `json:"chain_id"` + CoinType int64 `json:"coin_type"` } diff --git a/main.go b/main.go index 08b1042..8cf697e 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "context" "encoding/base64" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -67,14 +68,14 @@ func generate() int32 { return 1 } - did, err := initializeDatabase(credentialBytes) + result, err := initializeDatabase(credentialBytes, input.KeyShare) if err != nil { pdk.SetError(fmt.Errorf("generate: failed to initialize database: %w", err)) return 1 } state.SetInitialized(true) - state.SetDID(did) + state.SetDID(result.DID) dbBytes, err := serializeDatabase() if err != nil { @@ -83,8 +84,10 @@ func generate() int32 { } output := types.GenerateOutput{ - DID: did, - Database: dbBytes, + DID: result.DID, + Database: dbBytes, + KeyShareID: result.KeyShareID, + Account: result.Account, } if err := pdk.OutputJSON(output); err != nil { @@ -92,7 +95,7 @@ func generate() int32 { return 1 } - pdk.Log(pdk.LogInfo, fmt.Sprintf("generate: created DID %s", did)) + pdk.Log(pdk.LogInfo, fmt.Sprintf("generate: created DID %s", result.DID)) return 0 } @@ -238,20 +241,106 @@ func query() int32 { return 0 } -func initializeDatabase(credentialBytes []byte) (string, error) { +type initResult struct { + DID string + KeyShareID string + Account *types.AccountInfo +} + +func initializeDatabase(credentialBytes []byte, keyShareInput *types.KeyShareInput) (*initResult, error) { kb, err := keybase.Open() if err != nil { - return "", fmt.Errorf("open database: %w", err) + return nil, fmt.Errorf("open database: %w", err) } ctx := context.Background() did, err := kb.Initialize(ctx, credentialBytes) if err != nil { - return "", fmt.Errorf("initialize: %w", err) + return nil, fmt.Errorf("initialize: %w", err) + } + + result := &initResult{DID: did} + + if keyShareInput != nil { + keyShareID, account, err := createInitialKeyShare(ctx, keyShareInput) + if err != nil { + pdk.Log(pdk.LogWarn, fmt.Sprintf("initializeDatabase: failed to create keyshare: %s", err)) + } else { + result.KeyShareID = keyShareID + result.Account = account + pdk.Log(pdk.LogInfo, fmt.Sprintf("initializeDatabase: created keyshare %s", keyShareID)) + } } pdk.Log(pdk.LogDebug, "initializeDatabase: created schema and initial records") - return did, nil + return result, nil +} + +func createInitialKeyShare(ctx context.Context, input *types.KeyShareInput) (string, *types.AccountInfo, error) { + am, err := keybase.NewActionManager() + if err != nil { + return "", nil, fmt.Errorf("action manager: %w", err) + } + + ks, err := am.CreateKeyShare(ctx, keybase.NewKeyShareInput{ + KeyID: input.KeyID, + PartyIndex: input.PartyIndex, + Threshold: input.Threshold, + TotalParties: input.TotalParties, + Curve: input.Curve, + ShareData: input.ShareData, + PublicKey: input.PublicKey, + ChainCode: input.ChainCode, + DerivationPath: input.DerivationPath, + }) + if err != nil { + return "", nil, fmt.Errorf("create keyshare: %w", err) + } + + account, err := createInitialAccount(ctx, am, ks.ID, input.PublicKey) + if err != nil { + pdk.Log(pdk.LogWarn, fmt.Sprintf("createInitialKeyShare: failed to create account: %s", err)) + return ks.ShareID, nil, nil + } + + return ks.ShareID, account, nil +} + +func createInitialAccount(ctx context.Context, am *keybase.ActionManager, keyShareID int64, publicKey string) (*types.AccountInfo, error) { + address := deriveCosmosAddress(publicKey) + if address == "" { + return nil, fmt.Errorf("failed to derive address from public key") + } + + acc, err := am.CreateAccount(ctx, keybase.NewAccountInput{ + KeyShareID: keyShareID, + Address: address, + ChainID: "sonr-testnet-1", + CoinType: 118, + AccountIndex: 0, + AddressIndex: 0, + Label: "Default Account", + }) + if err != nil { + return nil, fmt.Errorf("create account: %w", err) + } + + return &types.AccountInfo{ + Address: acc.Address, + ChainID: acc.ChainID, + CoinType: acc.CoinType, + }, nil +} + +func deriveCosmosAddress(publicKeyHex string) string { + if publicKeyHex == "" { + return "" + } + pubBytes, err := hex.DecodeString(publicKeyHex) + if err != nil || len(pubBytes) < 20 { + return "" + } + return fmt.Sprintf("snr1%x", pubBytes[:20]) } func serializeDatabase() ([]byte, error) {