Files
motr-enclave/cmd/enclave/main.go

616 lines
14 KiB
Go

//go:build wasip1
package main
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"enclave/internal/crypto/bip44"
	"enclave/internal/crypto/mpc"
	"enclave/internal/keybase"
	"enclave/internal/state"
	"enclave/internal/types"

	"github.com/extism/go-pdk"
)
// main satisfies the Go requirement that package main define a main
// function. The //go:wasmexport functions in this file are the plugin's
// callable entry points. main only calls state.Default (presumably to
// install default plugin state — confirm in internal/state).
func main() {
	state.Default()
}
// ping is a liveness-check export. It decodes a types.PingInput and replies
// with Message "pong", echoing the caller's message in Echo.
//
// A malformed input is reported in-band (Success=false) and still returns 0.
// A non-zero return happens only when the reply cannot be written.
//
//go:wasmexport ping
func ping() int32 {
	pdk.Log(pdk.LogInfo, "ping: received request")

	var in types.PingInput
	if err := pdk.InputJSON(&in); err != nil {
		pdk.OutputJSON(types.PingOutput{
			Success: false,
			Message: fmt.Sprintf("failed to parse input: %s", err),
		})
		return 0
	}

	reply := types.PingOutput{Success: true, Message: "pong", Echo: in.Message}
	if err := pdk.OutputJSON(reply); err != nil {
		pdk.Log(pdk.LogError, fmt.Sprintf("ping: failed to output: %s", err))
		return 1
	}

	pdk.Log(pdk.LogInfo, fmt.Sprintf("ping: responded with echo=%s", in.Message))
	return 0
}
// generate bootstraps a new enclave from a credential. The credential is
// supplied base64-encoded (standard alphabet) in types.GenerateInput.
//
// Steps:
//   - decode the credential;
//   - initializeWithMPC;
//   - mark the plugin state initialized with the new DID;
//   - return the DID plus the serialized database as types.GenerateOutput.
//
// Every failure is reported via pdk.SetError with a "generate:" prefix and
// return code 1.
//
// NOTE(review): state is marked initialized before serialization. A
// serialize failure therefore leaves the state initialized.
//
//go:wasmexport generate
func generate() int32 {
	pdk.Log(pdk.LogInfo, "generate: starting")

	fail := func(err error) int32 {
		pdk.SetError(err)
		return 1
	}

	var req types.GenerateInput
	if err := pdk.InputJSON(&req); err != nil {
		return fail(fmt.Errorf("generate: failed to parse input: %w", err))
	}
	if req.Credential == "" {
		return fail(errors.New("generate: credential is required"))
	}

	credential, err := base64.StdEncoding.DecodeString(req.Credential)
	if err != nil {
		return fail(fmt.Errorf("generate: invalid base64 credential: %w", err))
	}

	res, err := initializeWithMPC(credential)
	if err != nil {
		return fail(fmt.Errorf("generate: %w", err))
	}
	state.SetInitialized(true)
	state.SetDID(res.DID)

	db, err := serializeDatabase()
	if err != nil {
		return fail(fmt.Errorf("generate: serialize: %w", err))
	}

	if err := pdk.OutputJSON(types.GenerateOutput{DID: res.DID, Database: db}); err != nil {
		return fail(fmt.Errorf("generate: output: %w", err))
	}

	pdk.Log(pdk.LogInfo, fmt.Sprintf("generate: created DID %s with enclave %s", res.DID, res.EnclaveID))
	return 0
}
// load restores a previously serialized database (types.LoadInput.Database)
// and marks the plugin state initialized with the DID that loadDatabase
// returns.
//
// Input/precondition errors are reported via pdk.SetError. A load failure is
// reported in-band as LoadOutput{Success: false}. All failures return 1.
//
//go:wasmexport load
func load() int32 {
	pdk.Log(pdk.LogInfo, "load: loading database from buffer")

	var req types.LoadInput
	if err := pdk.InputJSON(&req); err != nil {
		pdk.SetError(fmt.Errorf("load: failed to parse input: %w", err))
		return 1
	}
	if len(req.Database) == 0 {
		pdk.SetError(errors.New("load: database buffer is required"))
		return 1
	}

	did, loadErr := loadDatabase(req.Database)
	if loadErr != nil {
		pdk.OutputJSON(types.LoadOutput{Success: false, Error: loadErr.Error()})
		return 1
	}

	state.SetInitialized(true)
	state.SetDID(did)

	if err := pdk.OutputJSON(types.LoadOutput{Success: true, DID: did}); err != nil {
		pdk.SetError(fmt.Errorf("load: failed to output result: %w", err))
		return 1
	}

	pdk.Log(pdk.LogInfo, fmt.Sprintf("load: loaded database for DID %s", did))
	return 0
}
// exec runs one action, described by a filter string, against the loaded
// database and writes a types.ExecOutput.
//
// The input is a types.ExecInput. Filter is parsed by parseFilter (it must
// name a resource and an action). When Token is non-empty, validateUCAN
// checks it before the action runs; an empty Token skips authorization.
//
// Failures are reported in-band as ExecOutput{Success: false}:
//   - precondition and input errors return 0;
//   - authorization and execution errors return 1.
//
// NOTE(review): that 0/1 split looks accidental. It is kept because hosts
// may depend on it.
//
// If the output itself cannot be written, the error is raised with
// pdk.SetError and 1 is returned, so the host is never left without a
// result.
//
//go:wasmexport exec
func exec() int32 {
	pdk.Log(pdk.LogInfo, "exec: executing action")

	// reply writes out and returns rc. Unlike a bare pdk.OutputJSON call, it
	// does not silently drop a failed write.
	reply := func(out types.ExecOutput, rc int32) int32 {
		if err := pdk.OutputJSON(out); err != nil {
			pdk.SetError(fmt.Errorf("exec: failed to output result: %w", err))
			return 1
		}
		return rc
	}
	// fail writes an in-band error with the given return code.
	fail := func(msg string, rc int32) int32 {
		return reply(types.ExecOutput{Success: false, Error: msg}, rc)
	}

	if !state.IsInitialized() {
		return fail("database not initialized, call generate or load first", 0)
	}

	var input types.ExecInput
	if err := pdk.InputJSON(&input); err != nil {
		return fail(fmt.Sprintf("failed to parse input: %s", err), 0)
	}
	if input.Filter == "" {
		return fail("filter is required", 0)
	}

	params, err := parseFilter(input.Filter)
	if err != nil {
		return fail(fmt.Sprintf("invalid filter: %s", err), 0)
	}

	if input.Token != "" {
		if err := validateUCAN(input.Token, params); err != nil {
			return fail(fmt.Sprintf("authorization failed: %s", err.Error()), 1)
		}
	}

	result, err := executeAction(params)
	if err != nil {
		return fail(err.Error(), 1)
	}

	if rc := reply(types.ExecOutput{Success: true, Result: result}, 0); rc != 0 {
		return rc
	}
	pdk.Log(pdk.LogInfo, fmt.Sprintf("exec: completed %s on %s", params.Action, params.Resource))
	return 0
}
// query resolves a DID document and writes it as a types.QueryOutput.
//
// The plugin must be initialized. An empty input DID defaults to the
// enclave's own DID (state.GetDID), and the DID must start with "did:".
// Every failure is reported via pdk.SetError and returns 1.
//
//go:wasmexport query
func query() int32 {
	pdk.Log(pdk.LogInfo, "query: resolving DID document")

	fail := func(err error) int32 {
		pdk.SetError(err)
		return 1
	}

	if !state.IsInitialized() {
		return fail(errors.New("database not initialized, call generate or load first"))
	}

	var req types.QueryInput
	if err := pdk.InputJSON(&req); err != nil {
		return fail(fmt.Errorf("query: failed to parse input: %w", err))
	}

	target := req.DID
	if target == "" {
		target = state.GetDID()
	}
	if !strings.HasPrefix(target, "did:") {
		return fail(errors.New("query: invalid DID format"))
	}

	doc, err := resolveDID(target)
	if err != nil {
		return fail(fmt.Errorf("query: failed to resolve DID: %w", err))
	}
	if err := pdk.OutputJSON(doc); err != nil {
		return fail(fmt.Errorf("query: failed to output result: %w", err))
	}

	pdk.Log(pdk.LogInfo, fmt.Sprintf("query: resolved DID %s", target))
	return 0
}
// initResult bundles what initializeWithMPC produces:
//   - DID: the identifier returned by the keybase Initialize call;
//   - EnclaveID: "enc_" followed by the hex of the credential's leading bytes;
//   - PublicKey: the enclave public key in hex form;
//   - Accounts: the default accounts that were created (possibly empty).
type initResult struct {
	DID, EnclaveID, PublicKey string
	Accounts                  []types.AccountInfo
}
func initializeWithMPC(credentialBytes []byte) (*initResult, error) {
kb, err := keybase.Open()
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
ctx := context.Background()
did, err := kb.Initialize(ctx, credentialBytes)
if err != nil {
return nil, fmt.Errorf("initialize DID: %w", err)
}
simpleEnc, err := mpc.NewSimpleEnclave()
if err != nil {
return nil, fmt.Errorf("generate enclave: %w", err)
}
enclaveID := fmt.Sprintf("enc_%x", credentialBytes[:8])
am, err := keybase.NewActionManager()
if err != nil {
return nil, fmt.Errorf("action manager: %w", err)
}
enc, err := am.CreateEnclave(ctx, keybase.NewEnclaveInput{
EnclaveID: enclaveID,
PublicKeyHex: simpleEnc.PubKeyHex(),
PublicKey: simpleEnc.PubKeyBytes(),
ValShare: simpleEnc.GetShare1(),
UserShare: simpleEnc.GetShare2(),
Nonce: simpleEnc.GetNonce(),
Curve: string(simpleEnc.GetCurve()),
})
if err != nil {
return nil, fmt.Errorf("store enclave: %w", err)
}
accounts, err := createDefaultAccounts(ctx, am, enc.ID, simpleEnc.PubKeyBytes())
if err != nil {
pdk.Log(pdk.LogWarn, fmt.Sprintf("createDefaultAccounts: %s", err))
accounts = []types.AccountInfo{}
}
return &initResult{
DID: did,
EnclaveID: enclaveID,
PublicKey: simpleEnc.PubKeyHex(),
Accounts: accounts,
}, nil
}
// createDefaultAccounts derives accounts for the default chains ("sonr",
// "ethereum", "bitcoin") from the enclave public key with
// bip44.DeriveAccounts and persists each one via am.CreateAccount. Each
// account is labelled with its chain ID, and the first derived account is
// flagged as the default.
//
// Only a derivation failure is returned as an error. Persisting an
// individual account is best-effort: a failure is logged and that account is
// left out of the result, while the remaining accounts are still created.
func createDefaultAccounts(ctx context.Context, am *keybase.ActionManager, enclaveID int64, pubKeyBytes []byte) ([]types.AccountInfo, error) {
	chains := []string{"sonr", "ethereum", "bitcoin"}
	derivedAccounts, err := bip44.DeriveAccounts(pubKeyBytes, chains)
	if err != nil {
		return nil, fmt.Errorf("derive accounts: %w", err)
	}

	accounts := make([]types.AccountInfo, 0, len(derivedAccounts))
	for i, derived := range derivedAccounts {
		isDefault := int64(0)
		if i == 0 {
			isDefault = 1
		}
		acc, err := am.CreateAccount(ctx, keybase.NewAccountInput{
			EnclaveID:    enclaveID,
			Address:      derived.Address,
			ChainID:      derived.ChainID,
			CoinType:     int64(derived.CoinType),
			AccountIndex: int64(derived.AccountIndex),
			AddressIndex: int64(derived.AddressIndex),
			Label:        derived.ChainID,
			IsDefault:    isDefault,
		})
		if err != nil {
			// Keep going so the other chains still get accounts, but make the
			// gap visible instead of dropping it silently.
			pdk.Log(pdk.LogWarn, fmt.Sprintf("createDefaultAccounts: skipping %v account %v: %s",
				derived.ChainID, derived.Address, err))
			continue
		}
		accounts = append(accounts, types.AccountInfo{
			Address:  acc.Address,
			ChainID:  acc.ChainID,
			CoinType: acc.CoinType,
		})
	}
	return accounts, nil
}
// serializeDatabase returns the serialized bytes of the current keybase
// handle. It fails with "database not initialized" when keybase.Get
// returns nil.
func serializeDatabase() ([]byte, error) {
	if kb := keybase.Get(); kb != nil {
		return kb.Serialize()
	}
	return nil, errors.New("database not initialized")
}
func loadDatabase(data []byte) (string, error) {
if len(data) < 10 {
return "", errors.New("invalid database format")
}
kb, err := keybase.Open()
if err != nil {
return "", fmt.Errorf("open database: %w", err)
}
ctx := context.Background()
did, err := kb.Load(ctx, data)
if err != nil {
return "", fmt.Errorf("load DID: %w", err)
}
pdk.Log(pdk.LogDebug, "loadDatabase: database loaded successfully")
return did, nil
}
// parseFilter turns a whitespace-separated list of "key:value" tokens into
// FilterParams.
//
// Parsing rules:
//   - recognized keys are resource, action and subject;
//   - tokens without a colon, or with any other key, are ignored;
//   - a repeated key keeps its last value;
//   - only the first colon splits a token, so values may themselves contain
//     colons (e.g. "subject:did:sonr:abc").
//
// resource and action are required, and are checked in that order.
func parseFilter(filter string) (*types.FilterParams, error) {
	params := new(types.FilterParams)

	for _, token := range strings.Fields(filter) {
		key, value, found := strings.Cut(token, ":")
		if !found {
			continue
		}
		switch key {
		case "resource":
			params.Resource = value
		case "action":
			params.Action = value
		case "subject":
			params.Subject = value
		}
	}

	switch {
	case params.Resource == "":
		return nil, errors.New("resource is required")
	case params.Action == "":
		return nil, errors.New("action is required")
	}
	return params, nil
}
// validateUCAN runs lightweight checks on a UCAN-style JWT before exec
// performs an action. Checks, in order:
//   - the token is non-empty and has exactly three dot-separated segments;
//   - the middle segment is unpadded base64url that decodes to a JSON object;
//   - a numeric "exp" is not before currentUnixTime;
//   - a numeric "nbf" is not after currentUnixTime;
//   - a string "aud" that differs from the enclave DID is only logged at
//     debug level, not rejected;
//   - an "att" array, when present, must grant params.Resource and
//     params.Action (see checkAttenuations);
//   - a string "cid" is looked up with IsDelegationRevoked. This is
//     best-effort: errors from NewActionManager or the lookup are ignored.
//
// NOTE(review): the signature segment is never verified, and a token with no
// "att" claim passes the capability check. This function is therefore not
// an authorization boundary by itself.
func validateUCAN(token string, params *types.FilterParams) error {
	if token == "" {
		return errors.New("token is required")
	}

	segments := strings.Split(token, ".")
	if len(segments) != 3 {
		return errors.New("invalid token format: expected JWT with 3 parts")
	}

	rawClaims, err := base64.RawURLEncoding.DecodeString(segments[1])
	if err != nil {
		return fmt.Errorf("invalid token payload: %w", err)
	}
	var claims map[string]any
	if err := json.Unmarshal(rawClaims, &claims); err != nil {
		return fmt.Errorf("invalid token claims: %w", err)
	}

	if exp, ok := claims["exp"].(float64); ok && int64(exp) < currentUnixTime() {
		return errors.New("token has expired")
	}
	if nbf, ok := claims["nbf"].(float64); ok && int64(nbf) > currentUnixTime() {
		return errors.New("token is not yet valid")
	}

	if aud, ok := claims["aud"].(string); ok {
		if own := state.GetDID(); own != "" && aud != own {
			pdk.Log(pdk.LogDebug, fmt.Sprintf("validateUCAN: audience mismatch, expected %s got %s", own, aud))
		}
	}

	if att, ok := claims["att"].([]any); ok && !checkAttenuations(att, params.Resource, params.Action) {
		return fmt.Errorf("token does not grant capability for %s:%s", params.Resource, params.Action)
	}

	if am, amErr := keybase.NewActionManager(); amErr == nil {
		if cid, ok := claims["cid"].(string); ok {
			revoked, lookupErr := am.IsDelegationRevoked(context.Background(), cid)
			if lookupErr == nil && revoked {
				return errors.New("token has been revoked")
			}
		}
	}

	pdk.Log(pdk.LogDebug, fmt.Sprintf("validateUCAN: validated token for %s:%s", params.Resource, params.Action))
	return nil
}
// currentUnixTime returns the current wall-clock time in whole seconds since
// the Unix epoch. validateUCAN compares the "exp" and "nbf" claims against
// it.
//
// It must reflect the real clock. A constant here would let every "exp"
// check pass and make every positive "nbf" check fail. Under wasip1 the Go
// runtime reads the clock through the WASI clock_time_get import.
func currentUnixTime() int64 {
	return time.Now().Unix()
}
// checkAttenuations reports whether any entry in a UCAN "att" list grants
// action on resource.
//
// An entry grants access when:
//   - it is an object whose "with" string matches resource (per
//     matchResource); and
//   - its "can" is either a string or a list of strings that includes action
//     or "*".
//
// Entries of any other shape are skipped.
func checkAttenuations(attenuations []any, resource, action string) bool {
	grants := func(ability any) bool {
		s, ok := ability.(string)
		return ok && (s == "*" || s == action)
	}

	for _, entry := range attenuations {
		obj, ok := entry.(map[string]any)
		if !ok {
			continue
		}
		with, ok := obj["with"].(string)
		if !ok || !matchResource(with, resource) {
			continue
		}

		switch can := obj["can"].(type) {
		case string:
			if grants(can) {
				return true
			}
		case []any:
			for _, ability := range can {
				if grants(ability) {
					return true
				}
			}
		}
	}
	return false
}
// matchResource reports whether a UCAN "with" pattern covers resource.
//
// Supported forms:
//   - exact: "accounts" matches "accounts";
//   - trailing wildcard: "accounts/*" matches "accounts" and anything beneath
//     it ("accounts/eth"), but not a sibling that merely shares the prefix
//     ("accountsecrets"). A bare "/*" matches everything;
//   - scheme-qualified exact: "sonr://accounts" matches "accounts".
//
// NOTE(review): a scheme-qualified wildcard such as "sonr://accounts/*"
// keeps its scheme in the prefix, so it only matches resources that carry
// the same scheme.
func matchResource(pattern, resource string) bool {
	if pattern == resource {
		return true
	}
	if prefix, ok := strings.CutSuffix(pattern, "/*"); ok {
		// Require a path-segment boundary so "vault/*" cannot grant
		// "vaultkeys".
		return prefix == "" || resource == prefix || strings.HasPrefix(resource, prefix+"/")
	}
	_, rest, ok := strings.Cut(pattern, "://")
	return ok && rest == resource
}
// executeAction dispatches a parsed filter:
//   - resource "accounts" with action "balances" is served over HTTP by
//     fetchAccountBalances, using Subject as the address;
//   - every other resource/action pair is delegated to keybase.Exec.
func executeAction(params *types.FilterParams) (json.RawMessage, error) {
	switch {
	case params.Resource == "accounts" && params.Action == "balances":
		return fetchAccountBalances(params.Subject)
	default:
		return keybase.Exec(context.Background(), params.Resource, params.Action, params.Subject)
	}
}
// fetchAccountBalances GETs {api}/cosmos/bank/v1beta1/balances/{address}
// and returns the raw response body.
//
// Parameters and configuration:
//   - address defaults to the enclave DID when empty;
//   - {api} comes from the "api_endpoint" config key, falling back to
//     https://api.sonr.io when the key is missing or empty.
//
// address comes from the caller-supplied exec filter, so it is rejected when
// it could change the request path or query: if it is empty, ".", or "..",
// or contains "/", "?", "#", "%" or "\".
//
// A non-2xx response is not returned as a Go error. It is encoded as a JSON
// object with "error", "status" and "address" fields.
// NOTE(review): exec therefore reports Success=true for a failed upstream
// call; confirm this is intended.
func fetchAccountBalances(address string) (json.RawMessage, error) {
	if address == "" {
		address = state.GetDID()
	}
	if address == "" || address == "." || address == ".." || strings.ContainsAny(address, `/?#%\`) {
		return nil, fmt.Errorf("invalid address %q", address)
	}

	apiBase, ok := state.GetConfig("api_endpoint")
	if !ok || apiBase == "" {
		apiBase = "https://api.sonr.io"
	}

	endpoint := fmt.Sprintf("%s/cosmos/bank/v1beta1/balances/%s", apiBase, address)
	pdk.Log(pdk.LogInfo, fmt.Sprintf("fetchAccountBalances: GET %s", endpoint))

	req := pdk.NewHTTPRequest(pdk.MethodGet, endpoint)
	req.SetHeader("Accept", "application/json")
	res := req.Send()

	status := res.Status()
	if status < 200 || status >= 300 {
		pdk.Log(pdk.LogError, fmt.Sprintf("fetchAccountBalances: HTTP %d", status))
		return json.Marshal(map[string]any{
			"error":   "failed to fetch balances",
			"status":  status,
			"address": address,
		})
	}

	body := res.Body()
	pdk.Log(pdk.LogDebug, fmt.Sprintf("fetchAccountBalances: received %d bytes", len(body)))
	return body, nil
}
// resolveDID loads the DID document for did through the keybase action
// manager and copies it into the types.QueryOutput shape returned by the
// query export.
//
// Behaviour:
//   - account numeric fields (CoinType, AccountIndex, AddressIndex) are
//     converted to int;
//   - the output slices are always non-nil, even when empty.
func resolveDID(did string) (*types.QueryOutput, error) {
	am, err := keybase.NewActionManager()
	if err != nil {
		return nil, fmt.Errorf("action manager: %w", err)
	}

	doc, err := am.ResolveDID(context.Background(), did)
	if err != nil {
		return nil, fmt.Errorf("resolve DID: %w", err)
	}

	methods := make([]types.VerificationMethod, 0, len(doc.VerificationMethods))
	for _, m := range doc.VerificationMethods {
		methods = append(methods, types.VerificationMethod{
			ID:         m.ID,
			Type:       m.Type,
			Controller: m.Controller,
			PublicKey:  m.PublicKey,
			Purpose:    m.Purpose,
		})
	}

	accts := make([]types.Account, 0, len(doc.Accounts))
	for _, a := range doc.Accounts {
		accts = append(accts, types.Account{
			Address:      a.Address,
			ChainID:      a.ChainID,
			CoinType:     int(a.CoinType),
			AccountIndex: int(a.AccountIndex),
			AddressIndex: int(a.AddressIndex),
			Label:        a.Label,
			IsDefault:    a.IsDefault,
		})
	}

	creds := make([]types.Credential, 0, len(doc.Credentials))
	for _, c := range doc.Credentials {
		creds = append(creds, types.Credential{
			CredentialID:  c.CredentialID,
			DeviceName:    c.DeviceName,
			DeviceType:    c.DeviceType,
			Authenticator: c.Authenticator,
			Transports:    c.Transports,
			CreatedAt:     c.CreatedAt,
			LastUsed:      c.LastUsed,
		})
	}

	return &types.QueryOutput{
		DID:                 doc.DID,
		Controller:          doc.Controller,
		VerificationMethods: methods,
		Accounts:            accts,
		Credentials:         creds,
	}, nil
}