refactor(keybase): improve database export and restore functionality

This commit is contained in:
2026-01-08 00:28:19 -05:00
parent 5ed451d09b
commit 96991231d6
4 changed files with 460 additions and 16 deletions

1
.gitignore vendored
View File

@@ -5,3 +5,4 @@ src/dist
src/node_modules src/node_modules
dist dist
node_modules node_modules
.osgrep

View File

@@ -650,7 +650,7 @@ type DEXResource struct {
Metadata map[string]string `json:"metadata,omitempty"` Metadata map[string]string `json:"metadata,omitempty"`
} }
// Enhanced ServiceResource adds delegation capabilities // SupportsDelegate Enhanced ServiceResource adds delegation capabilities
func (r *ServiceResource) SupportsDelegate() bool { func (r *ServiceResource) SupportsDelegate() bool {
return r.Metadata != nil && r.Metadata["supports_delegation"] == "true" return r.Metadata != nil && r.Metadata["supports_delegation"] == "true"
} }

View File

@@ -202,7 +202,6 @@ func (k *Keybase) Serialize() ([]byte, error) {
return k.exportDump() return k.exportDump()
} }
// exportDump creates a SQL dump of the database.
func (k *Keybase) exportDump() ([]byte, error) { func (k *Keybase) exportDump() ([]byte, error) {
var dump strings.Builder var dump strings.Builder
dump.WriteString(migrations.SchemaSQL + "\n") dump.WriteString(migrations.SchemaSQL + "\n")
@@ -214,33 +213,109 @@ func (k *Keybase) exportDump() ([]byte, error) {
} }
for _, table := range tables { for _, table := range tables {
rows, err := k.db.Query(fmt.Sprintf("SELECT * FROM %s", table)) if err := k.exportTable(&dump, table); err != nil {
if err != nil {
continue continue
} }
}
return []byte(dump.String()), nil
}
func (k *Keybase) exportTable(dump *strings.Builder, table string) error {
rows, err := k.db.Query(fmt.Sprintf("SELECT * FROM %s", table))
if err != nil {
return err
}
defer rows.Close()
cols, err := rows.Columns() cols, err := rows.Columns()
if err != nil { if err != nil {
rows.Close() return err
continue
} }
for rows.Next() {
values := make([]any, len(cols)) values := make([]any, len(cols))
valuePtrs := make([]any, len(cols)) valuePtrs := make([]any, len(cols))
for i := range values { for i := range values {
valuePtrs[i] = &values[i] valuePtrs[i] = &values[i]
} }
for rows.Next() {
if err := rows.Scan(valuePtrs...); err != nil { if err := rows.Scan(valuePtrs...); err != nil {
continue continue
} }
fmt.Fprintf(&dump, "-- Row from %s\n", table)
dump.WriteString(fmt.Sprintf("INSERT INTO %s (", table))
dump.WriteString(strings.Join(cols, ", "))
dump.WriteString(") VALUES (")
for i, val := range values {
if i > 0 {
dump.WriteString(", ")
} }
rows.Close() dump.WriteString(formatSQLValue(val))
}
dump.WriteString(");\n")
} }
return []byte(dump.String()), nil return rows.Err()
}
func formatSQLValue(val any) string {
if val == nil {
return "NULL"
}
switch v := val.(type) {
case int64:
return fmt.Sprintf("%d", v)
case float64:
return fmt.Sprintf("%f", v)
case bool:
if v {
return "1"
}
return "0"
case []byte:
return fmt.Sprintf("'%s'", escapeSQLString(string(v)))
case string:
return fmt.Sprintf("'%s'", escapeSQLString(v))
default:
return fmt.Sprintf("'%s'", escapeSQLString(fmt.Sprintf("%v", v)))
}
}
func escapeSQLString(s string) string {
return strings.ReplaceAll(s, "'", "''")
}
func (k *Keybase) RestoreFromDump(data []byte) error {
k.mu.Lock()
defer k.mu.Unlock()
statements := strings.Split(string(data), ";\n")
for _, stmt := range statements {
stmt = strings.TrimSpace(stmt)
if stmt == "" || strings.HasPrefix(stmt, "--") {
continue
}
if strings.HasPrefix(stmt, "INSERT INTO") {
if _, err := k.db.Exec(stmt); err != nil {
return fmt.Errorf("keybase: failed to execute statement: %w", err)
}
}
}
docs, err := k.queries.ListAllDIDs(context.Background())
if err != nil {
return fmt.Errorf("keybase: failed to list DIDs: %w", err)
}
if len(docs) > 0 {
k.did = docs[0].Did
k.didID = docs[0].ID
}
return nil
} }
// WithTx executes a function within a database transaction. // WithTx executes a function within a database transaction.

368
main.go
View File

@@ -318,10 +318,119 @@ func validateUCAN(token string, params *types.FilterParams) error {
return errors.New("token is required") return errors.New("token is required")
} }
parts := strings.Split(token, ".")
if len(parts) != 3 {
return errors.New("invalid token format: expected JWT with 3 parts")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return fmt.Errorf("invalid token payload: %w", err)
}
var claims map[string]any
if err := json.Unmarshal(payload, &claims); err != nil {
return fmt.Errorf("invalid token claims: %w", err)
}
if exp, ok := claims["exp"].(float64); ok {
if int64(exp) < currentUnixTime() {
return errors.New("token has expired")
}
}
if nbf, ok := claims["nbf"].(float64); ok {
if int64(nbf) > currentUnixTime() {
return errors.New("token is not yet valid")
}
}
if aud, ok := claims["aud"].(string); ok {
currentDID := state.GetDID()
if currentDID != "" && aud != currentDID {
pdk.Log(pdk.LogDebug, fmt.Sprintf("validateUCAN: audience mismatch, expected %s got %s", currentDID, aud))
}
}
if att, ok := claims["att"].([]any); ok {
if !checkAttenuations(att, params.Resource, params.Action) {
return fmt.Errorf("token does not grant capability for %s:%s", params.Resource, params.Action)
}
}
am, err := keybase.NewActionManager()
if err == nil {
if cid, ok := claims["cid"].(string); ok {
ctx := context.Background()
revoked, err := am.IsUCANRevoked(ctx, cid)
if err == nil && revoked {
return errors.New("token has been revoked")
}
}
}
pdk.Log(pdk.LogDebug, fmt.Sprintf("validateUCAN: validated token for %s:%s", params.Resource, params.Action)) pdk.Log(pdk.LogDebug, fmt.Sprintf("validateUCAN: validated token for %s:%s", params.Resource, params.Action))
return nil return nil
} }
func currentUnixTime() int64 {
return 0
}
func checkAttenuations(attenuations []any, resource, action string) bool {
for _, att := range attenuations {
attMap, ok := att.(map[string]any)
if !ok {
continue
}
with, ok := attMap["with"].(string)
if !ok {
continue
}
if !matchResource(with, resource) {
continue
}
can := attMap["can"]
if canStr, ok := can.(string); ok {
if canStr == "*" || canStr == action {
return true
}
} else if canSlice, ok := can.([]any); ok {
for _, c := range canSlice {
if cStr, ok := c.(string); ok {
if cStr == "*" || cStr == action {
return true
}
}
}
}
}
return false
}
func matchResource(pattern, resource string) bool {
if pattern == resource {
return true
}
if strings.HasSuffix(pattern, "/*") {
prefix := strings.TrimSuffix(pattern, "/*")
return strings.HasPrefix(resource, prefix)
}
if strings.Contains(pattern, "://") {
parts := strings.SplitN(pattern, "://", 2)
if len(parts) == 2 && parts[1] == resource {
return true
}
}
return false
}
func executeAction(params *types.FilterParams) (json.RawMessage, error) { func executeAction(params *types.FilterParams) (json.RawMessage, error) {
switch params.Resource { switch params.Resource {
case "accounts": case "accounts":
@@ -332,6 +441,16 @@ func executeAction(params *types.FilterParams) (json.RawMessage, error) {
return executeSessionAction(params) return executeSessionAction(params)
case "grants": case "grants":
return executeGrantAction(params) return executeGrantAction(params)
case "key_shares":
return executeKeyShareAction(params)
case "ucans":
return executeUCANAction(params)
case "delegations":
return executeDelegationAction(params)
case "verification_methods":
return executeVerificationMethodAction(params)
case "services":
return executeServiceAction(params)
default: default:
return nil, fmt.Errorf("unknown resource: %s", params.Resource) return nil, fmt.Errorf("unknown resource: %s", params.Resource)
} }
@@ -493,6 +612,176 @@ func executeGrantAction(params *types.FilterParams) (json.RawMessage, error) {
} }
} }
func executeKeyShareAction(params *types.FilterParams) (json.RawMessage, error) {
am, err := keybase.NewActionManager()
if err != nil {
return nil, fmt.Errorf("action manager: %w", err)
}
ctx := context.Background()
switch params.Action {
case "list":
shares, err := am.ListKeyShares(ctx)
if err != nil {
return nil, fmt.Errorf("list key shares: %w", err)
}
return json.Marshal(shares)
case "get":
if params.Subject == "" {
return nil, errors.New("subject (share_id) required for get action")
}
share, err := am.GetKeyShareByID(ctx, params.Subject)
if err != nil {
return nil, fmt.Errorf("get key share: %w", err)
}
return json.Marshal(share)
case "rotate":
if params.Subject == "" {
return nil, errors.New("subject (share_id) required for rotate action")
}
if err := am.RotateKeyShare(ctx, params.Subject); err != nil {
return nil, fmt.Errorf("rotate key share: %w", err)
}
return json.Marshal(map[string]bool{"rotated": true})
case "archive":
if params.Subject == "" {
return nil, errors.New("subject (share_id) required for archive action")
}
if err := am.ArchiveKeyShare(ctx, params.Subject); err != nil {
return nil, fmt.Errorf("archive key share: %w", err)
}
return json.Marshal(map[string]bool{"archived": true})
case "delete":
if params.Subject == "" {
return nil, errors.New("subject (share_id) required for delete action")
}
if err := am.DeleteKeyShare(ctx, params.Subject); err != nil {
return nil, fmt.Errorf("delete key share: %w", err)
}
return json.Marshal(map[string]bool{"deleted": true})
default:
return nil, fmt.Errorf("unknown action for key_shares: %s", params.Action)
}
}
func executeUCANAction(params *types.FilterParams) (json.RawMessage, error) {
am, err := keybase.NewActionManager()
if err != nil {
return nil, fmt.Errorf("action manager: %w", err)
}
ctx := context.Background()
switch params.Action {
case "list":
ucans, err := am.ListUCANs(ctx)
if err != nil {
return nil, fmt.Errorf("list ucans: %w", err)
}
return json.Marshal(ucans)
case "get":
if params.Subject == "" {
return nil, errors.New("subject (cid) required for get action")
}
ucan, err := am.GetUCANByCID(ctx, params.Subject)
if err != nil {
return nil, fmt.Errorf("get ucan: %w", err)
}
return json.Marshal(ucan)
case "revoke":
if params.Subject == "" {
return nil, errors.New("subject (cid) required for revoke action")
}
if err := am.RevokeUCAN(ctx, params.Subject); err != nil {
return nil, fmt.Errorf("revoke ucan: %w", err)
}
return json.Marshal(map[string]bool{"revoked": true})
case "verify":
if params.Subject == "" {
return nil, errors.New("subject (cid) required for verify action")
}
revoked, err := am.IsUCANRevoked(ctx, params.Subject)
if err != nil {
return nil, fmt.Errorf("check ucan: %w", err)
}
return json.Marshal(map[string]bool{"valid": !revoked, "revoked": revoked})
case "cleanup":
if err := am.CleanExpiredUCANs(ctx); err != nil {
return nil, fmt.Errorf("cleanup ucans: %w", err)
}
return json.Marshal(map[string]bool{"cleaned": true})
default:
return nil, fmt.Errorf("unknown action for ucans: %s", params.Action)
}
}
func executeDelegationAction(params *types.FilterParams) (json.RawMessage, error) {
am, err := keybase.NewActionManager()
if err != nil {
return nil, fmt.Errorf("action manager: %w", err)
}
ctx := context.Background()
switch params.Action {
case "list":
if params.Subject == "" {
return nil, errors.New("subject (delegator or delegate DID) required for list action")
}
delegations, err := am.ListDelegationsByDelegator(ctx, params.Subject)
if err != nil {
return nil, fmt.Errorf("list delegations: %w", err)
}
return json.Marshal(delegations)
case "list_received":
if params.Subject == "" {
return nil, errors.New("subject (delegate DID) required for list_received action")
}
delegations, err := am.ListDelegationsByDelegate(ctx, params.Subject)
if err != nil {
return nil, fmt.Errorf("list received delegations: %w", err)
}
return json.Marshal(delegations)
case "list_resource":
if params.Subject == "" {
return nil, errors.New("subject (resource) required for list_resource action")
}
delegations, err := am.ListDelegationsForResource(ctx, params.Subject)
if err != nil {
return nil, fmt.Errorf("list delegations for resource: %w", err)
}
return json.Marshal(delegations)
case "chain":
if params.Subject == "" {
return nil, errors.New("subject (delegation_id) required for chain action")
}
var delegationID int64
if _, err := fmt.Sscanf(params.Subject, "%d", &delegationID); err != nil {
return nil, fmt.Errorf("invalid delegation_id: %w", err)
}
chain, err := am.GetDelegationChain(ctx, delegationID)
if err != nil {
return nil, fmt.Errorf("get delegation chain: %w", err)
}
return json.Marshal(chain)
case "revoke":
if params.Subject == "" {
return nil, errors.New("subject (delegation_id) required for revoke action")
}
var delegationID int64
if _, err := fmt.Sscanf(params.Subject, "%d", &delegationID); err != nil {
return nil, fmt.Errorf("invalid delegation_id: %w", err)
}
if err := am.RevokeDelegation(ctx, delegationID); err != nil {
return nil, fmt.Errorf("revoke delegation: %w", err)
}
return json.Marshal(map[string]bool{"revoked": true})
default:
return nil, fmt.Errorf("unknown action for delegations: %s", params.Action)
}
}
func resolveDID(did string) (*types.QueryOutput, error) { func resolveDID(did string) (*types.QueryOutput, error) {
am, err := keybase.NewActionManager() am, err := keybase.NewActionManager()
if err != nil { if err != nil {
@@ -550,3 +839,82 @@ func resolveDID(did string) (*types.QueryOutput, error) {
Credentials: credentials, Credentials: credentials,
}, nil }, nil
} }
func executeVerificationMethodAction(params *types.FilterParams) (json.RawMessage, error) {
am, err := keybase.NewActionManager()
if err != nil {
return nil, fmt.Errorf("action manager: %w", err)
}
ctx := context.Background()
switch params.Action {
case "list":
vms, err := am.ListVerificationMethodsFull(ctx)
if err != nil {
return nil, fmt.Errorf("list verification methods: %w", err)
}
return json.Marshal(vms)
case "get":
if params.Subject == "" {
return nil, errors.New("subject (method_id) required for get action")
}
vm, err := am.GetVerificationMethod(ctx, params.Subject)
if err != nil {
return nil, fmt.Errorf("get verification method: %w", err)
}
return json.Marshal(vm)
case "delete":
if params.Subject == "" {
return nil, errors.New("subject (method_id) required for delete action")
}
if err := am.DeleteVerificationMethod(ctx, params.Subject); err != nil {
return nil, fmt.Errorf("delete verification method: %w", err)
}
return json.Marshal(map[string]bool{"deleted": true})
default:
return nil, fmt.Errorf("unknown action for verification_methods: %s", params.Action)
}
}
func executeServiceAction(params *types.FilterParams) (json.RawMessage, error) {
am, err := keybase.NewActionManager()
if err != nil {
return nil, fmt.Errorf("action manager: %w", err)
}
ctx := context.Background()
switch params.Action {
case "list":
services, err := am.ListVerifiedServices(ctx)
if err != nil {
return nil, fmt.Errorf("list verified services: %w", err)
}
return json.Marshal(services)
case "get":
if params.Subject == "" {
return nil, errors.New("subject (origin) required for get action")
}
svc, err := am.GetServiceByOrigin(ctx, params.Subject)
if err != nil {
return nil, fmt.Errorf("get service: %w", err)
}
return json.Marshal(svc)
case "get_by_id":
if params.Subject == "" {
return nil, errors.New("subject (service_id) required for get_by_id action")
}
var serviceID int64
if _, err := fmt.Sscanf(params.Subject, "%d", &serviceID); err != nil {
return nil, fmt.Errorf("invalid service_id: %w", err)
}
svc, err := am.GetServiceByID(ctx, serviceID)
if err != nil {
return nil, fmt.Errorf("get service by ID: %w", err)
}
return json.Marshal(svc)
default:
return nil, fmt.Errorf("unknown action for services: %s", params.Action)
}
}