diff --git a/ucl/userbuiltin.go b/ucl/userbuiltin.go index 591a73d..46c493e 100644 --- a/ucl/userbuiltin.go +++ b/ucl/userbuiltin.go @@ -4,6 +4,8 @@ import ( "context" "errors" "reflect" + + "github.com/lmika/gopkgs/fp/slices" ) type BuiltinHandler func(ctx context.Context, args CallArgs) (any, error) @@ -24,7 +26,7 @@ func (ca *CallArgs) Bind(vars ...interface{}) error { } for i, v := range vars { - if err := bindArg(v, ca.args.args[i]); err != nil { + if err := ca.bindArg(v, ca.args.args[i]); err != nil { return err } } @@ -72,7 +74,7 @@ func (ca CallArgs) BindSwitch(name string, val interface{}) error { return nil } - return bindArg(val, (*vars)[0]) + return ca.bindArg(val, (*vars)[0]) } func (inst *Inst) SetBuiltin(name string, fn BuiltinHandler) { @@ -92,10 +94,22 @@ func (u userBuiltin) invoke(ctx context.Context, args invocationArgs) (object, e return fromGoValue(v) } -func bindArg(v interface{}, arg object) error { +func (ca CallArgs) bindArg(v interface{}, arg object) error { switch t := v.(type) { case *interface{}: *t, _ = toGoValue(arg) + case *Invokable: + i, ok := arg.(invokable) + if !ok { + return errors.New("exepected invokable") + } + *t = Invokable{ + inv: i, + eval: ca.args.eval, + inst: ca.args.inst, + ec: ca.args.ec, + } + return nil case *string: *t = arg.String() case *int: @@ -194,3 +208,37 @@ func (m missingHandlerInvokable) invoke(ctx context.Context, args invocationArgs return fromGoValue(v) } + +type Invokable struct { + inv invokable + eval evaluator + inst *Inst + ec *evalCtx +} + +func (i Invokable) Invoke(ctx context.Context, args ...any) (any, error) { + var err error + invArgs := invocationArgs{ + eval: i.eval, + ec: i.ec, + inst: i.inst, + } + + invArgs.args, err = slices.MapWithError(args, func(a any) (object, error) { + return fromGoValue(a) + }) + if err != nil { + return nil, err + } + + res, err := i.inv.invoke(ctx, invArgs) + if err != nil { + return nil, err + } + + goRes, ok := toGoValue(res) + if !ok { + return nil, errors.New("cannot convert result to Go Value") + } + return goRes, err +} diff --git a/ucl/userbuiltin_test.go b/ucl/userbuiltin_test.go index 3778e76..3e6ada6 100644 --- a/ucl/userbuiltin_test.go +++ b/ucl/userbuiltin_test.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" "testing" + "ucl.lmika.dev/ucl" "github.com/stretchr/testify/assert" @@ -265,6 +266,30 @@ func TestCallArgs_CanBind(t *testing.T) { assert.Equal(t, tt.want, res) }) } + + t.Run("can bind invokable", func(t *testing.T) { + inst := ucl.New() + inst.SetBuiltin("wrap", func(ctx context.Context, args ucl.CallArgs) (any, error) { + var inv ucl.Invokable + + if err := args.Bind(&inv); err != nil { + return nil, err + } + + res, err := inv.Invoke(ctx, "hello") + if err != nil { + return nil, err + } + + return fmt.Sprintf("[[%v]]", res), nil + }) + + ctx := context.Background() + + res, err := inst.Eval(ctx, `wrap { |x| toUpper $x }`) + assert.NoError(t, err) + assert.Equal(t, "[[HELLO]]", res) + }) } func TestCallArgs_MissingCommandHandler(t *testing.T) {