refactor(keybase): improve database export and restore functionality
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,3 +5,4 @@ src/dist
|
||||
src/node_modules
|
||||
dist
|
||||
node_modules
|
||||
.osgrep
|
||||
|
||||
@@ -650,7 +650,7 @@ type DEXResource struct {
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// Enhanced ServiceResource adds delegation capabilities
|
||||
// SupportsDelegate Enhanced ServiceResource adds delegation capabilities
|
||||
func (r *ServiceResource) SupportsDelegate() bool {
|
||||
return r.Metadata != nil && r.Metadata["supports_delegation"] == "true"
|
||||
}
|
||||
|
||||
@@ -202,7 +202,6 @@ func (k *Keybase) Serialize() ([]byte, error) {
|
||||
return k.exportDump()
|
||||
}
|
||||
|
||||
// exportDump creates a SQL dump of the database.
|
||||
func (k *Keybase) exportDump() ([]byte, error) {
|
||||
var dump strings.Builder
|
||||
dump.WriteString(migrations.SchemaSQL + "\n")
|
||||
@@ -214,33 +213,109 @@ func (k *Keybase) exportDump() ([]byte, error) {
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
rows, err := k.db.Query(fmt.Sprintf("SELECT * FROM %s", table))
|
||||
if err != nil {
|
||||
if err := k.exportTable(&dump, table); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
continue
|
||||
}
|
||||
return []byte(dump.String()), nil
|
||||
}
|
||||
|
||||
func (k *Keybase) exportTable(dump *strings.Builder, table string) error {
|
||||
rows, err := k.db.Query(fmt.Sprintf("SELECT * FROM %s", table))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
values := make([]any, len(cols))
|
||||
valuePtrs := make([]any, len(cols))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(&dump, "-- Row from %s\n", table)
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
continue
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
dump.WriteString(fmt.Sprintf("INSERT INTO %s (", table))
|
||||
dump.WriteString(strings.Join(cols, ", "))
|
||||
dump.WriteString(") VALUES (")
|
||||
|
||||
for i, val := range values {
|
||||
if i > 0 {
|
||||
dump.WriteString(", ")
|
||||
}
|
||||
dump.WriteString(formatSQLValue(val))
|
||||
}
|
||||
dump.WriteString(");\n")
|
||||
}
|
||||
|
||||
return []byte(dump.String()), nil
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func formatSQLValue(val any) string {
|
||||
if val == nil {
|
||||
return "NULL"
|
||||
}
|
||||
|
||||
switch v := val.(type) {
|
||||
case int64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case float64:
|
||||
return fmt.Sprintf("%f", v)
|
||||
case bool:
|
||||
if v {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
case []byte:
|
||||
return fmt.Sprintf("'%s'", escapeSQLString(string(v)))
|
||||
case string:
|
||||
return fmt.Sprintf("'%s'", escapeSQLString(v))
|
||||
default:
|
||||
return fmt.Sprintf("'%s'", escapeSQLString(fmt.Sprintf("%v", v)))
|
||||
}
|
||||
}
|
||||
|
||||
func escapeSQLString(s string) string {
|
||||
return strings.ReplaceAll(s, "'", "''")
|
||||
}
|
||||
|
||||
func (k *Keybase) RestoreFromDump(data []byte) error {
|
||||
k.mu.Lock()
|
||||
defer k.mu.Unlock()
|
||||
|
||||
statements := strings.Split(string(data), ";\n")
|
||||
for _, stmt := range statements {
|
||||
stmt = strings.TrimSpace(stmt)
|
||||
if stmt == "" || strings.HasPrefix(stmt, "--") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(stmt, "INSERT INTO") {
|
||||
if _, err := k.db.Exec(stmt); err != nil {
|
||||
return fmt.Errorf("keybase: failed to execute statement: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
docs, err := k.queries.ListAllDIDs(context.Background())
|
||||
if err != nil {
|
||||
return fmt.Errorf("keybase: failed to list DIDs: %w", err)
|
||||
}
|
||||
|
||||
if len(docs) > 0 {
|
||||
k.did = docs[0].Did
|
||||
k.didID = docs[0].ID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithTx executes a function within a database transaction.
|
||||
|
||||
368
main.go
368
main.go
@@ -318,10 +318,119 @@ func validateUCAN(token string, params *types.FilterParams) error {
|
||||
return errors.New("token is required")
|
||||
}
|
||||
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return errors.New("invalid token format: expected JWT with 3 parts")
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid token payload: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return fmt.Errorf("invalid token claims: %w", err)
|
||||
}
|
||||
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if int64(exp) < currentUnixTime() {
|
||||
return errors.New("token has expired")
|
||||
}
|
||||
}
|
||||
|
||||
if nbf, ok := claims["nbf"].(float64); ok {
|
||||
if int64(nbf) > currentUnixTime() {
|
||||
return errors.New("token is not yet valid")
|
||||
}
|
||||
}
|
||||
|
||||
if aud, ok := claims["aud"].(string); ok {
|
||||
currentDID := state.GetDID()
|
||||
if currentDID != "" && aud != currentDID {
|
||||
pdk.Log(pdk.LogDebug, fmt.Sprintf("validateUCAN: audience mismatch, expected %s got %s", currentDID, aud))
|
||||
}
|
||||
}
|
||||
|
||||
if att, ok := claims["att"].([]any); ok {
|
||||
if !checkAttenuations(att, params.Resource, params.Action) {
|
||||
return fmt.Errorf("token does not grant capability for %s:%s", params.Resource, params.Action)
|
||||
}
|
||||
}
|
||||
|
||||
am, err := keybase.NewActionManager()
|
||||
if err == nil {
|
||||
if cid, ok := claims["cid"].(string); ok {
|
||||
ctx := context.Background()
|
||||
revoked, err := am.IsUCANRevoked(ctx, cid)
|
||||
if err == nil && revoked {
|
||||
return errors.New("token has been revoked")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pdk.Log(pdk.LogDebug, fmt.Sprintf("validateUCAN: validated token for %s:%s", params.Resource, params.Action))
|
||||
return nil
|
||||
}
|
||||
|
||||
func currentUnixTime() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func checkAttenuations(attenuations []any, resource, action string) bool {
|
||||
for _, att := range attenuations {
|
||||
attMap, ok := att.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
with, ok := attMap["with"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if !matchResource(with, resource) {
|
||||
continue
|
||||
}
|
||||
|
||||
can := attMap["can"]
|
||||
if canStr, ok := can.(string); ok {
|
||||
if canStr == "*" || canStr == action {
|
||||
return true
|
||||
}
|
||||
} else if canSlice, ok := can.([]any); ok {
|
||||
for _, c := range canSlice {
|
||||
if cStr, ok := c.(string); ok {
|
||||
if cStr == "*" || cStr == action {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchResource(pattern, resource string) bool {
|
||||
if pattern == resource {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.HasSuffix(pattern, "/*") {
|
||||
prefix := strings.TrimSuffix(pattern, "/*")
|
||||
return strings.HasPrefix(resource, prefix)
|
||||
}
|
||||
|
||||
if strings.Contains(pattern, "://") {
|
||||
parts := strings.SplitN(pattern, "://", 2)
|
||||
if len(parts) == 2 && parts[1] == resource {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func executeAction(params *types.FilterParams) (json.RawMessage, error) {
|
||||
switch params.Resource {
|
||||
case "accounts":
|
||||
@@ -332,6 +441,16 @@ func executeAction(params *types.FilterParams) (json.RawMessage, error) {
|
||||
return executeSessionAction(params)
|
||||
case "grants":
|
||||
return executeGrantAction(params)
|
||||
case "key_shares":
|
||||
return executeKeyShareAction(params)
|
||||
case "ucans":
|
||||
return executeUCANAction(params)
|
||||
case "delegations":
|
||||
return executeDelegationAction(params)
|
||||
case "verification_methods":
|
||||
return executeVerificationMethodAction(params)
|
||||
case "services":
|
||||
return executeServiceAction(params)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown resource: %s", params.Resource)
|
||||
}
|
||||
@@ -493,6 +612,176 @@ func executeGrantAction(params *types.FilterParams) (json.RawMessage, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func executeKeyShareAction(params *types.FilterParams) (json.RawMessage, error) {
|
||||
am, err := keybase.NewActionManager()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("action manager: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
switch params.Action {
|
||||
case "list":
|
||||
shares, err := am.ListKeyShares(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list key shares: %w", err)
|
||||
}
|
||||
return json.Marshal(shares)
|
||||
case "get":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (share_id) required for get action")
|
||||
}
|
||||
share, err := am.GetKeyShareByID(ctx, params.Subject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get key share: %w", err)
|
||||
}
|
||||
return json.Marshal(share)
|
||||
case "rotate":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (share_id) required for rotate action")
|
||||
}
|
||||
if err := am.RotateKeyShare(ctx, params.Subject); err != nil {
|
||||
return nil, fmt.Errorf("rotate key share: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]bool{"rotated": true})
|
||||
case "archive":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (share_id) required for archive action")
|
||||
}
|
||||
if err := am.ArchiveKeyShare(ctx, params.Subject); err != nil {
|
||||
return nil, fmt.Errorf("archive key share: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]bool{"archived": true})
|
||||
case "delete":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (share_id) required for delete action")
|
||||
}
|
||||
if err := am.DeleteKeyShare(ctx, params.Subject); err != nil {
|
||||
return nil, fmt.Errorf("delete key share: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]bool{"deleted": true})
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown action for key_shares: %s", params.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func executeUCANAction(params *types.FilterParams) (json.RawMessage, error) {
|
||||
am, err := keybase.NewActionManager()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("action manager: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
switch params.Action {
|
||||
case "list":
|
||||
ucans, err := am.ListUCANs(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list ucans: %w", err)
|
||||
}
|
||||
return json.Marshal(ucans)
|
||||
case "get":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (cid) required for get action")
|
||||
}
|
||||
ucan, err := am.GetUCANByCID(ctx, params.Subject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get ucan: %w", err)
|
||||
}
|
||||
return json.Marshal(ucan)
|
||||
case "revoke":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (cid) required for revoke action")
|
||||
}
|
||||
if err := am.RevokeUCAN(ctx, params.Subject); err != nil {
|
||||
return nil, fmt.Errorf("revoke ucan: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]bool{"revoked": true})
|
||||
case "verify":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (cid) required for verify action")
|
||||
}
|
||||
revoked, err := am.IsUCANRevoked(ctx, params.Subject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check ucan: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]bool{"valid": !revoked, "revoked": revoked})
|
||||
case "cleanup":
|
||||
if err := am.CleanExpiredUCANs(ctx); err != nil {
|
||||
return nil, fmt.Errorf("cleanup ucans: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]bool{"cleaned": true})
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown action for ucans: %s", params.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func executeDelegationAction(params *types.FilterParams) (json.RawMessage, error) {
|
||||
am, err := keybase.NewActionManager()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("action manager: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
switch params.Action {
|
||||
case "list":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (delegator or delegate DID) required for list action")
|
||||
}
|
||||
delegations, err := am.ListDelegationsByDelegator(ctx, params.Subject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list delegations: %w", err)
|
||||
}
|
||||
return json.Marshal(delegations)
|
||||
case "list_received":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (delegate DID) required for list_received action")
|
||||
}
|
||||
delegations, err := am.ListDelegationsByDelegate(ctx, params.Subject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list received delegations: %w", err)
|
||||
}
|
||||
return json.Marshal(delegations)
|
||||
case "list_resource":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (resource) required for list_resource action")
|
||||
}
|
||||
delegations, err := am.ListDelegationsForResource(ctx, params.Subject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list delegations for resource: %w", err)
|
||||
}
|
||||
return json.Marshal(delegations)
|
||||
case "chain":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (delegation_id) required for chain action")
|
||||
}
|
||||
var delegationID int64
|
||||
if _, err := fmt.Sscanf(params.Subject, "%d", &delegationID); err != nil {
|
||||
return nil, fmt.Errorf("invalid delegation_id: %w", err)
|
||||
}
|
||||
chain, err := am.GetDelegationChain(ctx, delegationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get delegation chain: %w", err)
|
||||
}
|
||||
return json.Marshal(chain)
|
||||
case "revoke":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (delegation_id) required for revoke action")
|
||||
}
|
||||
var delegationID int64
|
||||
if _, err := fmt.Sscanf(params.Subject, "%d", &delegationID); err != nil {
|
||||
return nil, fmt.Errorf("invalid delegation_id: %w", err)
|
||||
}
|
||||
if err := am.RevokeDelegation(ctx, delegationID); err != nil {
|
||||
return nil, fmt.Errorf("revoke delegation: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]bool{"revoked": true})
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown action for delegations: %s", params.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveDID(did string) (*types.QueryOutput, error) {
|
||||
am, err := keybase.NewActionManager()
|
||||
if err != nil {
|
||||
@@ -550,3 +839,82 @@ func resolveDID(did string) (*types.QueryOutput, error) {
|
||||
Credentials: credentials,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func executeVerificationMethodAction(params *types.FilterParams) (json.RawMessage, error) {
|
||||
am, err := keybase.NewActionManager()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("action manager: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
switch params.Action {
|
||||
case "list":
|
||||
vms, err := am.ListVerificationMethodsFull(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list verification methods: %w", err)
|
||||
}
|
||||
return json.Marshal(vms)
|
||||
case "get":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (method_id) required for get action")
|
||||
}
|
||||
vm, err := am.GetVerificationMethod(ctx, params.Subject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get verification method: %w", err)
|
||||
}
|
||||
return json.Marshal(vm)
|
||||
case "delete":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (method_id) required for delete action")
|
||||
}
|
||||
if err := am.DeleteVerificationMethod(ctx, params.Subject); err != nil {
|
||||
return nil, fmt.Errorf("delete verification method: %w", err)
|
||||
}
|
||||
return json.Marshal(map[string]bool{"deleted": true})
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown action for verification_methods: %s", params.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func executeServiceAction(params *types.FilterParams) (json.RawMessage, error) {
|
||||
am, err := keybase.NewActionManager()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("action manager: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
switch params.Action {
|
||||
case "list":
|
||||
services, err := am.ListVerifiedServices(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list verified services: %w", err)
|
||||
}
|
||||
return json.Marshal(services)
|
||||
case "get":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (origin) required for get action")
|
||||
}
|
||||
svc, err := am.GetServiceByOrigin(ctx, params.Subject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get service: %w", err)
|
||||
}
|
||||
return json.Marshal(svc)
|
||||
case "get_by_id":
|
||||
if params.Subject == "" {
|
||||
return nil, errors.New("subject (service_id) required for get_by_id action")
|
||||
}
|
||||
var serviceID int64
|
||||
if _, err := fmt.Sscanf(params.Subject, "%d", &serviceID); err != nil {
|
||||
return nil, fmt.Errorf("invalid service_id: %w", err)
|
||||
}
|
||||
svc, err := am.GetServiceByID(ctx, serviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get service by ID: %w", err)
|
||||
}
|
||||
return json.Marshal(svc)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown action for services: %s", params.Action)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user