diff --git a/ucl/builtins.go b/ucl/builtins.go index 2e67655..a13544d 100644 --- a/ucl/builtins.go +++ b/ucl/builtins.go @@ -68,6 +68,10 @@ func eqBuiltin(ctx context.Context, args invocationArgs) (object, error) { if rv, ok := r.(strObject); ok { return boolObject(lv == rv), nil } + case intObject: + if rv, ok := r.(intObject); ok { + return boolObject(lv == rv), nil + } } return boolObject(false), nil } @@ -309,7 +313,10 @@ func foreachBuiltin(ctx context.Context, args macroArgs) (object, error) { blockIdx = 0 } - var last object + var ( + last object + breakErr errBreak + ) switch t := items.(type) { case listable: @@ -318,14 +325,26 @@ func foreachBuiltin(ctx context.Context, args macroArgs) (object, error) { v := t.Index(i) last, err = args.evalBlock(ctx, blockIdx, []object{v}, true) // TO INCLUDE: the index if err != nil { - return nil, err + if errors.As(err, &breakErr) { + if !breakErr.isCont { + return breakErr.ret, nil + } + } else { + return nil, err + } } } case hashObject: for k, v := range t { last, err = args.evalBlock(ctx, blockIdx, []object{strObject(k), v}, true) if err != nil { - return nil, err + if errors.As(err, &breakErr) { + if !breakErr.isCont { + return breakErr.ret, nil + } + } else { + return nil, err + } } } } @@ -333,6 +352,24 @@ func foreachBuiltin(ctx context.Context, args macroArgs) (object, error) { return last, nil } +func breakBuiltin(ctx context.Context, args invocationArgs) (object, error) { + if len(args.args) < 1 { + return nil, errBreak{} + } + return nil, errBreak{ret: args.args[0]} +} + +func continueBuiltin(ctx context.Context, args invocationArgs) (object, error) { + return nil, errBreak{isCont: true} +} + +func returnBuiltin(ctx context.Context, args invocationArgs) (object, error) { + if len(args.args) < 1 { + return nil, errReturn{} + } + return nil, errReturn{ret: args.args[0]} +} + func procBuiltin(ctx context.Context, args macroArgs) (object, error) { if args.nargs() < 1 { return nil, errors.New("need at least one arguments") @@ -388,5 +425,13 @@ func (b procObject) invoke(ctx context.Context, args invocationArgs) (object, er } } - return b.eval.evalBlock(ctx, newEc, b.block) + res, err := b.eval.evalBlock(ctx, newEc, b.block) + if err != nil { + var er errReturn + if errors.As(err, &er) { + return er.ret, nil + } + return nil, err + } + return res, nil } diff --git a/ucl/inst.go b/ucl/inst.go index f29c14a..7493338 100644 --- a/ucl/inst.go +++ b/ucl/inst.go @@ -39,6 +39,9 @@ func New(opts ...InstOption) *Inst { rootEC.addCmd("eq", invokableFunc(eqBuiltin)) rootEC.addCmd("cat", invokableFunc(concatBuiltin)) + rootEC.addCmd("break", invokableFunc(breakBuiltin)) + rootEC.addCmd("continue", invokableFunc(continueBuiltin)) + rootEC.addCmd("return", invokableFunc(returnBuiltin)) rootEC.addMacro("if", macroFunc(ifBuiltin)) rootEC.addMacro("foreach", macroFunc(foreachBuiltin)) diff --git a/ucl/objs.go b/ucl/objs.go index 43aeec6..47bd7cb 100644 --- a/ucl/objs.go +++ b/ucl/objs.go @@ -445,3 +445,23 @@ func (s structProxyObject) Each(fn func(k string, v object) error) error { } return nil } + +type errBreak struct { + isCont bool + ret object +} + +func (e errBreak) Error() string { + if e.isCont { + return "continue" + } + return "break" +} + +type errReturn struct { + ret object +} + +func (e errReturn) Error() string { + return "return" +} diff --git a/ucl/testbuiltins_test.go b/ucl/testbuiltins_test.go index 31be850..11f6c8d 100644 --- a/ucl/testbuiltins_test.go +++ b/ucl/testbuiltins_test.go @@ -217,6 +217,91 @@ func TestBuiltins_ForEach(t *testing.T) { } } +func TestBuiltins_Break(t *testing.T) { + tests := []struct { + desc string + expr string + want string + }{ + {desc: "break unconditionally returning nothing", expr: ` + foreach ["1" "2" "3"] { |v| + break + echo $v + }`, want: "(nil)\n"}, + {desc: "break conditionally returning nothing", expr: ` + foreach ["1" "2" "3"] { |v| + echo $v + if (eq $v "2") { break } + }`, want: "1\n2\n(nil)\n"}, + {desc: "break inner loop only returning nothing", expr: ` + foreach ["a" "b"] { |u| + foreach ["1" "2" "3"] { |v| + echo $u $v + if (eq $v "2") { break } + } + }`, want: "a1\na2\nb1\nb2\n(nil)\n"}, + {desc: "break returning value", expr: ` + echo (foreach ["1" "2" "3"] { |v| + echo $v + if (eq $v "2") { break "hello" } + })`, want: "1\n2\nhello\n(nil)\n"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + ctx := context.Background() + outW := bytes.NewBuffer(nil) + + inst := New(WithOut(outW), WithTestBuiltin()) + err := EvalAndDisplay(ctx, inst, tt.expr) + + assert.NoError(t, err) + assert.Equal(t, tt.want, outW.String()) + }) + } +} + +func TestBuiltins_Continue(t *testing.T) { + tests := []struct { + desc string + expr string + want string + }{ + {desc: "continue unconditionally", expr: ` + foreach ["1" "2" "3"] { |v| + echo $v "s" + continue + echo $v "e" + }`, want: "1s\n2s\n3s\n(nil)\n"}, + {desc: "conditionally conditionally", expr: ` + foreach ["1" "2" "3"] { |v| + echo $v "s" + if (eq $v "2") { continue } + echo $v "e" + }`, want: "1s\n1e\n2s\n3s\n3e\n(nil)\n"}, + {desc: "continue inner loop only", expr: ` + foreach ["a" "b"] { |u| + foreach ["1" "2" "3"] { |v| + if (eq $v "2") { continue } + echo $u $v + } + }`, want: "a1\na3\nb1\nb3\n(nil)\n"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + ctx := context.Background() + outW := bytes.NewBuffer(nil) + + inst := New(WithOut(outW), WithTestBuiltin()) + err := EvalAndDisplay(ctx, inst, tt.expr) + + assert.NoError(t, err) + assert.Equal(t, tt.want, outW.String()) + }) + } +} + func TestBuiltins_Procs(t *testing.T) { tests := []struct { desc string @@ -297,6 +382,79 @@ func TestBuiltins_Procs(t *testing.T) { } } +func TestBuiltins_Return(t *testing.T) { + tests := []struct { + desc string + expr string + want string + }{ + {desc: "nil return", expr: ` + proc greet { + echo "Hello" + return + echo "World" + } + + greet + `, want: "Hello\n(nil)\n"}, + {desc: "simple return", expr: ` + proc greet { + return "Hello, world" + echo "But not me" + } + + greet + `, want: "Hello, world\n"}, + {desc: "only return current frame", expr: ` + proc greetWhat { + echo "Greet the" + return "moon" + echo "world" + } + proc greet { + set what (greetWhat) + echo "Hello, " $what + } + + greet + `, want: "Greet the\nHello, moon\n(nil)\n"}, + {desc: "return in loop", expr: ` + proc countdown { |nums| + foreach $nums { |n| + echo $n + if (eq $n 3) { + return "abort" + } + } + } + countdown [5 4 3 2 1] + `, want: "5\n4\n3\nabort\n"}, + {desc: "recursive procs", expr: ` + proc four4 { |xs| + if (eq $xs "xxxx") { + return $xs + } + four4 (cat $xs "x") + } + + four4 + `, want: "xxxx\n"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + ctx := context.Background() + outW := bytes.NewBuffer(nil) + + inst := New(WithOut(outW), WithTestBuiltin()) + err := EvalAndDisplay(ctx, inst, tt.expr) + + assert.NoError(t, err) + assert.Equal(t, tt.want, outW.String()) + }) + } +} + func TestBuiltins_Map(t *testing.T) { tests := []struct { desc string