From abbae1a08f1a2c5373e698f9850b497347925ddb Mon Sep 17 00:00:00 2001
From: Leon Mika <lmika@lmika.org>
Date: Tue, 30 Apr 2024 20:55:06 +1000
Subject: [PATCH] Added break, continue, and return

---
 ucl/builtins.go          |  53 ++++++++++++-
 ucl/inst.go              |   3 +
 ucl/objs.go              |  20 +++++
 ucl/testbuiltins_test.go | 158 +++++++++++++++++++++++++++++++++++++++
 4 files changed, 230 insertions(+), 4 deletions(-)

diff --git a/ucl/builtins.go b/ucl/builtins.go
index 2e67655..a13544d 100644
--- a/ucl/builtins.go
+++ b/ucl/builtins.go
@@ -68,6 +68,10 @@ func eqBuiltin(ctx context.Context, args invocationArgs) (object, error) {
 		if rv, ok := r.(strObject); ok {
 			return boolObject(lv == rv), nil
 		}
+	case intObject:
+		if rv, ok := r.(intObject); ok {
+			return boolObject(lv == rv), nil
+		}
 	}
 	return boolObject(false), nil
 }
@@ -309,7 +313,10 @@ func foreachBuiltin(ctx context.Context, args macroArgs) (object, error) {
 		blockIdx = 0
 	}
 
-	var last object
+	var (
+		last     object
+		breakErr errBreak
+	)
 
 	switch t := items.(type) {
 	case listable:
@@ -318,14 +325,26 @@ func foreachBuiltin(ctx context.Context, args macroArgs) (object, error) {
 			v := t.Index(i)
 			last, err = args.evalBlock(ctx, blockIdx, []object{v}, true) // TO INCLUDE: the index
 			if err != nil {
-				return nil, err
+				if errors.As(err, &breakErr) {
+					if !breakErr.isCont {
+						return breakErr.ret, nil
+					}
+				} else {
+					return nil, err
+				}
 			}
 		}
 	case hashObject:
 		for k, v := range t {
 			last, err = args.evalBlock(ctx, blockIdx, []object{strObject(k), v}, true)
 			if err != nil {
-				return nil, err
+				if errors.As(err, &breakErr) {
+					if !breakErr.isCont {
+						return breakErr.ret, nil
+					}
+				} else {
+					return nil, err
+				}
 			}
 		}
 	}
@@ -333,6 +352,24 @@ func foreachBuiltin(ctx context.Context, args macroArgs) (object, error) {
 	return last, nil
 }
 
+func breakBuiltin(ctx context.Context, args invocationArgs) (object, error) {
+	if len(args.args) < 1 {
+		return nil, errBreak{}
+	}
+	return nil, errBreak{ret: args.args[0]}
+}
+
+func continueBuiltin(ctx context.Context, args invocationArgs) (object, error) {
+	return nil, errBreak{isCont: true}
+}
+
+func returnBuiltin(ctx context.Context, args invocationArgs) (object, error) {
+	if len(args.args) < 1 {
+		return nil, errReturn{}
+	}
+	return nil, errReturn{ret: args.args[0]}
+}
+
 func procBuiltin(ctx context.Context, args macroArgs) (object, error) {
 	if args.nargs() < 1 {
 		return nil, errors.New("need at least one arguments")
@@ -388,5 +425,13 @@ func (b procObject) invoke(ctx context.Context, args invocationArgs) (object, er
 		}
 	}
 
-	return b.eval.evalBlock(ctx, newEc, b.block)
+	res, err := b.eval.evalBlock(ctx, newEc, b.block)
+	if err != nil {
+		var er errReturn
+		if errors.As(err, &er) {
+			return er.ret, nil
+		}
+		return nil, err
+	}
+	return res, nil
 }
diff --git a/ucl/inst.go b/ucl/inst.go
index f29c14a..7493338 100644
--- a/ucl/inst.go
+++ b/ucl/inst.go
@@ -39,6 +39,9 @@ func New(opts ...InstOption) *Inst {
 
 	rootEC.addCmd("eq", invokableFunc(eqBuiltin))
 	rootEC.addCmd("cat", invokableFunc(concatBuiltin))
+	rootEC.addCmd("break", invokableFunc(breakBuiltin))
+	rootEC.addCmd("continue", invokableFunc(continueBuiltin))
+	rootEC.addCmd("return", invokableFunc(returnBuiltin))
 
 	rootEC.addMacro("if", macroFunc(ifBuiltin))
 	rootEC.addMacro("foreach", macroFunc(foreachBuiltin))
diff --git a/ucl/objs.go b/ucl/objs.go
index fbdfc2d..029984a 100644
--- a/ucl/objs.go
+++ b/ucl/objs.go
@@ -437,3 +437,23 @@ func (s structProxyObject) Each(fn func(k string, v object) error) error {
 	}
 	return nil
 }
+
+type errBreak struct {
+	isCont bool
+	ret    object
+}
+
+func (e errBreak) Error() string {
+	if e.isCont {
+		return "continue"
+	}
+	return "break"
+}
+
+type errReturn struct {
+	ret object
+}
+
+func (e errReturn) Error() string {
+	return "return"
+}
diff --git a/ucl/testbuiltins_test.go b/ucl/testbuiltins_test.go
index 31be850..11f6c8d 100644
--- a/ucl/testbuiltins_test.go
+++ b/ucl/testbuiltins_test.go
@@ -217,6 +217,91 @@ func TestBuiltins_ForEach(t *testing.T) {
 	}
 }
 
+func TestBuiltins_Break(t *testing.T) {
+	tests := []struct {
+		desc string
+		expr string
+		want string
+	}{
+		{desc: "break unconditionally returning nothing", expr: `
+			foreach ["1" "2" "3"] { |v|
+				break
+				echo $v
+			}`, want: "(nil)\n"},
+		{desc: "break conditionally returning nothing", expr: `
+			foreach ["1" "2" "3"] { |v|
+				echo $v
+				if (eq $v "2") { break }
+			}`, want: "1\n2\n(nil)\n"},
+		{desc: "break inner loop only returning nothing", expr: `
+			foreach ["a" "b"] { |u|
+				foreach ["1" "2" "3"] { |v|
+					echo $u $v
+					if (eq $v "2") { break }
+				}
+			}`, want: "a1\na2\nb1\nb2\n(nil)\n"},
+		{desc: "break returning value", expr: `
+			echo (foreach ["1" "2" "3"] { |v|
+				echo $v
+				if (eq $v "2") { break "hello" }
+			})`, want: "1\n2\nhello\n(nil)\n"},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.desc, func(t *testing.T) {
+			ctx := context.Background()
+			outW := bytes.NewBuffer(nil)
+
+			inst := New(WithOut(outW), WithTestBuiltin())
+			err := EvalAndDisplay(ctx, inst, tt.expr)
+
+			assert.NoError(t, err)
+			assert.Equal(t, tt.want, outW.String())
+		})
+	}
+}
+
+func TestBuiltins_Continue(t *testing.T) {
+	tests := []struct {
+		desc string
+		expr string
+		want string
+	}{
+		{desc: "continue unconditionally", expr: `
+			foreach ["1" "2" "3"] { |v|
+				echo $v "s"
+				continue
+				echo $v "e"
+			}`, want: "1s\n2s\n3s\n(nil)\n"},
+		{desc: "conditionally conditionally", expr: `
+			foreach ["1" "2" "3"] { |v|
+				echo $v "s"
+				if (eq $v "2") { continue }
+				echo $v "e"
+			}`, want: "1s\n1e\n2s\n3s\n3e\n(nil)\n"},
+		{desc: "continue inner loop only", expr: `
+			foreach ["a" "b"] { |u|
+				foreach ["1" "2" "3"] { |v|	
+					if (eq $v "2") { continue }
+					echo $u $v
+				}
+			}`, want: "a1\na3\nb1\nb3\n(nil)\n"},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.desc, func(t *testing.T) {
+			ctx := context.Background()
+			outW := bytes.NewBuffer(nil)
+
+			inst := New(WithOut(outW), WithTestBuiltin())
+			err := EvalAndDisplay(ctx, inst, tt.expr)
+
+			assert.NoError(t, err)
+			assert.Equal(t, tt.want, outW.String())
+		})
+	}
+}
+
 func TestBuiltins_Procs(t *testing.T) {
 	tests := []struct {
 		desc string
@@ -297,6 +382,79 @@ func TestBuiltins_Procs(t *testing.T) {
 	}
 }
 
+func TestBuiltins_Return(t *testing.T) {
+	tests := []struct {
+		desc string
+		expr string
+		want string
+	}{
+		{desc: "nil return", expr: `
+			proc greet {
+				echo "Hello"
+				return
+				echo "World"
+			}
+
+			greet
+			`, want: "Hello\n(nil)\n"},
+		{desc: "simple return", expr: `
+			proc greet {
+				return "Hello, world"
+				echo "But not me"
+			}
+
+			greet
+			`, want: "Hello, world\n"},
+		{desc: "only return current frame", expr: `
+			proc greetWhat {
+				echo "Greet the"
+				return "moon"
+				echo "world"
+			}
+			proc greet {
+				set what (greetWhat)
+				echo "Hello, " $what
+			}
+
+			greet
+			`, want: "Greet the\nHello, moon\n(nil)\n"},
+		{desc: "return in loop", expr: `
+			proc countdown { |nums|
+				foreach $nums { |n|
+					echo $n
+					if (eq $n 3) {
+						return "abort"
+					}
+				}
+			}
+			countdown [5 4 3 2 1]
+			`, want: "5\n4\n3\nabort\n"},
+		{desc: "recursive procs", expr: `
+			proc four4 { |xs|
+				if (eq $xs "xxxx") {
+					return $xs
+				}
+				four4 (cat $xs "x")	
+			}
+
+			four4
+			`, want: "xxxx\n"},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.desc, func(t *testing.T) {
+			ctx := context.Background()
+			outW := bytes.NewBuffer(nil)
+
+			inst := New(WithOut(outW), WithTestBuiltin())
+			err := EvalAndDisplay(ctx, inst, tt.expr)
+
+			assert.NoError(t, err)
+			assert.Equal(t, tt.want, outW.String())
+		})
+	}
+}
+
 func TestBuiltins_Map(t *testing.T) {
 	tests := []struct {
 		desc string