diff --git a/ucl/objs.go b/ucl/objs.go index 0bc520c..43aeec6 100644 --- a/ucl/objs.go +++ b/ucl/objs.go @@ -297,6 +297,9 @@ func (ia invocationArgs) fork(args []object) invocationArgs { } func (ia invocationArgs) shift(i int) invocationArgs { + if len(ia.args) < i { + return ia + } return invocationArgs{ eval: ia.eval, inst: ia.inst, diff --git a/ucl/userbuiltin.go b/ucl/userbuiltin.go index fd1a2b0..7811b8c 100644 --- a/ucl/userbuiltin.go +++ b/ucl/userbuiltin.go @@ -24,6 +24,23 @@ func (ca *CallArgs) Bind(vars ...interface{}) error { return nil } +func (ca *CallArgs) CanBind(vars ...interface{}) bool { + if len(ca.args.args) < len(vars) { + return false + } + + for i, v := range vars { + if !canBindArg(v, ca.args.args[i]) { + return false + } + } + return true +} + +func (ca *CallArgs) Shift(n int) { + ca.args = ca.args.shift(n) +} + func (ca CallArgs) IsTopLevel() bool { return ca.args.ec.parent == nil || ca.args.ec == ca.args.ec.root } @@ -71,6 +88,12 @@ func bindArg(v interface{}, arg object) error { switch t := v.(type) { case *string: *t = arg.String() + case *int: + if iArg, ok := arg.(intObject); ok { + *t = int(iArg) + } else { + return errors.New("invalid arg") + } } switch t := arg.(type) { @@ -85,6 +108,27 @@ func bindArg(v interface{}, arg object) error { return nil } +func canBindArg(v interface{}, arg object) bool { + switch v.(type) { + case *string: + return true + case *int: + _, ok := arg.(intObject) + return ok + } + + switch t := arg.(type) { + case proxyObject: + return canBindProxyObject(v, reflect.ValueOf(t.p)) + case listableProxyObject: + return canBindProxyObject(v, t.v) + case structProxyObject: + return canBindProxyObject(v, t.v) + } + + return true +} + func bindProxyObject(v interface{}, r reflect.Value) error { argValue := reflect.ValueOf(v) if argValue.Kind() != reflect.Ptr { @@ -103,3 +147,22 @@ func bindProxyObject(v interface{}, r reflect.Value) error { r = r.Elem() } } + +func canBindProxyObject(v interface{}, r reflect.Value) bool { + argValue := reflect.ValueOf(v) + if argValue.Kind() != reflect.Ptr { + return false + } + + for { + if r.Type().AssignableTo(argValue.Elem().Type()) { + argValue.Elem().Set(r) + return true + } + if r.Type().Kind() != reflect.Pointer { + return true + } + + r = r.Elem() + } +} diff --git a/ucl/userbuiltin_test.go b/ucl/userbuiltin_test.go index c4f2a9d..1e48320 100644 --- a/ucl/userbuiltin_test.go +++ b/ucl/userbuiltin_test.go @@ -214,6 +214,62 @@ func TestCallArgs_Bind(t *testing.T) { }) } +func TestCallArgs_CanBind(t *testing.T) { + t.Run("returns ture of all passed in arguments can be bound without consuming them", func(t *testing.T) { + tests := []struct { + descr string + eval string + want []string + }{ + {descr: "bind nothing", eval: `test`, want: []string{}}, + {descr: "bind one", eval: `test "yes"`, want: []string{"str"}}, + {descr: "bind two", eval: `test "yes" 213`, want: []string{"str", "int"}}, + {descr: "bind three", eval: `test "yes" 213 (proxy)`, want: []string{"all", "str", "int", "proxy"}}, + } + + for _, tt := range tests { + t.Run(tt.descr, func(t *testing.T) { + type proxyObj struct{} + + ctx := context.Background() + res := make([]string, 0) + + inst := ucl.New() + inst.SetBuiltin("proxy", func(ctx context.Context, args ucl.CallArgs) (any, error) { + return proxyObj{}, nil + }) + inst.SetBuiltin("test", func(ctx context.Context, args ucl.CallArgs) (any, error) { + var ( + s string + i int + p proxyObj + ) + + if args.CanBind(&s, &i, &p) { + res = append(res, "all") + } + if args.CanBind(&s) { + res = append(res, "str") + } + args.Shift(1) + if args.CanBind(&i) { + res = append(res, "int") + } + args.Shift(1) + if args.CanBind(&p) { + res = append(res, "proxy") + } + return nil, nil + }) + + _, err := inst.Eval(ctx, tt.eval) + assert.NoError(t, err) + assert.Equal(t, tt.want, res) + }) + } + }) +} + func TestCallArgs_IsTopLevel(t *testing.T) { t.Run("true if the command is running at the top-level frame", func(t *testing.T) { ctx := context.Background()