From d6cc449b407209fb962d92567a1be1269a8783a2 Mon Sep 17 00:00:00 2001 From: Leon Mika Date: Sat, 13 Apr 2024 21:46:50 +1000 Subject: [PATCH] Added macros and the if macro --- cmdlang/ast.go | 26 ++++++----- cmdlang/builtins.go | 59 +++++++++++++++++-------- cmdlang/env.go | 49 +++++++++++++------- cmdlang/eval.go | 49 ++++++++++++++++---- cmdlang/inst.go | 4 +- cmdlang/inst_test.go | 103 ++++++++++++++++++++++++++++++++++++++----- cmdlang/objs.go | 85 +++++++++++++++++++++++++++++++++++ cmdlang/streams.go | 12 +++++ 8 files changed, 323 insertions(+), 64 deletions(-) diff --git a/cmdlang/ast.go b/cmdlang/ast.go index 48b70c0..d5984aa 100644 --- a/cmdlang/ast.go +++ b/cmdlang/ast.go @@ -7,14 +7,19 @@ import ( ) type astLiteral struct { - Str *string `parser:"@String"` - Ident *string `parser:" | @Ident"` + Str *string `parser:"@String"` +} + +type astBlock struct { + Statements []*astStatements `parser:"LC NL? @@ NL? RC"` } type astCmdArg struct { Literal *astLiteral `parser:"@@"` + Ident *string `parser:"| @Ident"` Var *string `parser:"| DOLLAR @Ident"` Sub *astPipeline `parser:"| LP @@ RP"` + Block *astBlock `parser:"| @@"` } type astCmd struct { @@ -29,29 +34,30 @@ type astPipeline struct { type astStatements struct { First *astPipeline `parser:"@@"` - Rest []*astPipeline `parser:"( (SEMICL | NL)+ @@ )*"` // TODO: also add support for newlines + Rest []*astPipeline `parser:"( NL+ @@ )*"` // TODO: also add support for newlines } -type astBlock struct { - Statements *astStatements `parser:"'{' "` +type astScript struct { + Statements *astStatements `parser:"NL* @@ NL*"` } var scanner = lexer.MustStateful(lexer.Rules{ "Root": { - {"Whitespace", `[ ]`, nil}, - {"NL", `\n\s*`, nil}, + {"Whitespace", `[ \t]+`, nil}, {"String", `"(\\"|[^"])*"`, nil}, {"DOLLAR", `\$`, nil}, {"LP", `\(`, nil}, {"RP", `\)`, nil}, - {"SEMICL", `;`, nil}, + {"LC", `\{`, nil}, + {"RC", `\}`, nil}, + {"NL", `[;\n][; \n\t]*`, nil}, {"PIPE", `\|`, nil}, {"Ident", `\w+`, nil}, }, }) -var parser = participle.MustBuild[astStatements](participle.Lexer(scanner), +var parser = participle.MustBuild[astScript](participle.Lexer(scanner), participle.Elide("Whitespace")) -func parse(r io.Reader) (*astStatements, error) { +func parse(r io.Reader) (*astScript, error) { return parser.Parse("test", r) } diff --git a/cmdlang/builtins.go b/cmdlang/builtins.go index 546531f..2918dac 100644 --- a/cmdlang/builtins.go +++ b/cmdlang/builtins.go @@ -3,6 +3,7 @@ package cmdlang import ( "bufio" "context" + "errors" "fmt" "io" "os" @@ -14,6 +15,7 @@ func echoBuiltin(ctx context.Context, args invocationArgs) (object, error) { if _, err := fmt.Fprintln(args.inst.Out()); err != nil { return nil, err } + return nil, nil } var line strings.Builder @@ -83,6 +85,10 @@ func (f *fileLinesStream) String() string { return fmt.Sprintf("fileLinesStream{file: %v}", f.filename) } +func (f *fileLinesStream) Truthy() bool { + return true // ?? +} + func (f *fileLinesStream) next() (object, error) { var err error @@ -111,25 +117,40 @@ func (f *fileLinesStream) close() error { return nil } -/* -func errorTestBuiltin(ctx context.Context, inStream stream, args invocationArgs) (object, error) { - return &timeBombStream{inStream, 2}, nil -} - -type timeBombStream struct { - in stream - x int -} - -func (ms *timeBombStream) next() (object, error) { - if ms.x > 0 { - ms.x-- - return ms.in.next() +func ifBuiltin(ctx context.Context, args macroArgs) (object, error) { + if args.nargs() < 2 { + return nil, errors.New("need at least 2 arguments") } - return nil, errors.New("BOOM") -} -func (ms *timeBombStream) close() error { - return ms.in.close() + if guard, err := args.evalArg(ctx, 0); err == nil && isTruthy(guard) { + return args.evalBlock(ctx, 1) + } else if err != nil { + return nil, err + } + + args.shift(2) + for args.identIs(ctx, 0, "elif") { + args.shift(1) + + if args.nargs() < 2 { + return nil, errors.New("need at least 2 arguments") + } + + if guard, err := args.evalArg(ctx, 0); err == nil && isTruthy(guard) { + return args.evalBlock(ctx, 1) + } else if err != nil { + return nil, err + } + + args.shift(2) + } + + if args.identIs(ctx, 0, "else") && args.nargs() > 1 { + return args.evalBlock(ctx, 1) + } else if args.nargs() == 0 { + // no elif or else + return nil, nil + } + + return nil, errors.New("malformed if-elif-else") } -*/ diff --git a/cmdlang/env.go b/cmdlang/env.go index 947b931..5e35fd5 100644 --- a/cmdlang/env.go +++ b/cmdlang/env.go @@ -1,12 +1,9 @@ package cmdlang -import ( - "errors" -) - type evalCtx struct { parent *evalCtx commands map[string]invokable + macros map[string]macroable vars map[string]object } @@ -18,6 +15,14 @@ func (ec *evalCtx) addCmd(name string, inv invokable) { ec.commands[name] = inv } +func (ec *evalCtx) addMacro(name string, inv macroable) { + if ec.macros == nil { + ec.macros = make(map[string]macroable) + } + + ec.macros[name] = inv +} + func (ec *evalCtx) setVar(name string, val object) { if ec.vars == nil { ec.vars = make(map[string]object) @@ -39,16 +44,30 @@ func (ec *evalCtx) getVar(name string) (object, bool) { return nil, false } -func (ec *evalCtx) lookupCmd(name string) (invokable, error) { - for e := ec; e != nil; e = e.parent { - if e.commands == nil { - continue - } - - if cmd, ok := e.commands[name]; ok { - return cmd, nil - } - +func (ec *evalCtx) lookupInvokable(name string) invokable { + if ec == nil { + return nil } - return nil, errors.New("name " + name + " not found") + + for e := ec; e != nil; e = e.parent { + if cmd, ok := e.commands[name]; ok { + return cmd + } + } + + return ec.parent.lookupInvokable(name) +} + +func (ec *evalCtx) lookupMacro(name string) macroable { + if ec == nil { + return nil + } + + for e := ec; e != nil; e = e.parent { + if cmd, ok := e.macros[name]; ok { + return cmd + } + } + + return ec.parent.lookupMacro(name) } diff --git a/cmdlang/eval.go b/cmdlang/eval.go index 587bb13..77cd445 100644 --- a/cmdlang/eval.go +++ b/cmdlang/eval.go @@ -13,6 +13,22 @@ type evaluator struct { inst *Inst } +func (e evaluator) evalBlock(ctx context.Context, ec *evalCtx, n *astBlock) (lastRes object, err error) { + // TODO: push scope? + + for _, s := range n.Statements { + lastRes, err = e.evalStatement(ctx, ec, s) + if err != nil { + return nil, err + } + } + return lastRes, nil +} + +func (e evaluator) evalScript(ctx context.Context, ec *evalCtx, n *astScript) (lastRes object, err error) { + return e.evalStatement(ctx, ec, n.Statements) +} + func (e evaluator) evalStatement(ctx context.Context, ec *evalCtx, n *astStatements) (object, error) { res, err := e.evalPipeline(ctx, ec, n.First) if err != nil { @@ -60,11 +76,16 @@ func (e evaluator) evalPipeline(ctx context.Context, ec *evalCtx, n *astPipeline } func (e evaluator) evalCmd(ctx context.Context, ec *evalCtx, currentStream stream, ast *astCmd) (object, error) { - cmd, err := ec.lookupCmd(ast.Name) - if err != nil { - return nil, err + if cmd := ec.lookupInvokable(ast.Name); cmd != nil { + return e.evalInvokable(ctx, ec, currentStream, ast, cmd) + } else if macro := ec.lookupMacro(ast.Name); macro != nil { + return e.evalMacro(ctx, ec, currentStream, ast, macro) } + return nil, errors.New("unknown command") +} + +func (e evaluator) evalInvokable(ctx context.Context, ec *evalCtx, currentStream stream, ast *astCmd, cmd invokable) (object, error) { args, err := slices.MapWithError(ast.Args, func(a astCmdArg) (object, error) { return e.evalArg(ctx, ec, a) }) @@ -87,18 +108,30 @@ func (e evaluator) evalCmd(ctx context.Context, ec *evalCtx, currentStream strea return cmd.invoke(ctx, invArgs) } +func (e evaluator) evalMacro(ctx context.Context, ec *evalCtx, currentStream stream, ast *astCmd, cmd macroable) (object, error) { + return cmd.invokeMacro(ctx, macroArgs{ + eval: e, + ec: ec, + currentStream: currentStream, + ast: ast, + }) +} + func (e evaluator) evalArg(ctx context.Context, ec *evalCtx, n astCmdArg) (object, error) { switch { case n.Literal != nil: return e.evalLiteral(ctx, ec, n.Literal) + case n.Ident != nil: + return strObject(*n.Ident), nil case n.Var != nil: - v, ok := ec.getVar(*n.Var) - if !ok { - return nil, fmt.Errorf("unknown variable %s", *n.Var) + if v, ok := ec.getVar(*n.Var); ok { + return v, nil } - return v, nil + return nil, nil case n.Sub != nil: return e.evalSub(ctx, ec, n.Sub) + case n.Block != nil: + return blockObject{block: n.Block}, nil } return nil, errors.New("unhandled arg type") } @@ -111,8 +144,6 @@ func (e evaluator) evalLiteral(ctx context.Context, ec *evalCtx, n *astLiteral) return nil, err } return strObject(uq), nil - case n.Ident != nil: - return strObject(*n.Ident), nil } return nil, errors.New("unhandled literal type") } diff --git a/cmdlang/inst.go b/cmdlang/inst.go index 6b7ea5d..a1646c0 100644 --- a/cmdlang/inst.go +++ b/cmdlang/inst.go @@ -31,6 +31,8 @@ func New(opts ...InstOption) *Inst { rootEC.addCmd("toUpper", invokableStreamFunc(toUpperBuiltin)) rootEC.addCmd("cat", invokableFunc(catBuiltin)) + rootEC.addMacro("if", macroFunc(ifBuiltin)) + //rootEC.addCmd("testTimebomb", invokableStreamFunc(errorTestBuiltin)) rootEC.setVar("hello", strObject("world")) @@ -76,7 +78,7 @@ func (inst *Inst) eval(ctx context.Context, expr string) (object, error) { eval := evaluator{inst: inst} - return eval.evalStatement(ctx, inst.rootEC, ast) + return eval.evalScript(ctx, inst.rootEC, ast) } func (inst *Inst) EvalAndDisplay(ctx context.Context, expr string) error { diff --git a/cmdlang/inst_test.go b/cmdlang/inst_test.go index 7156bcf..96455df 100644 --- a/cmdlang/inst_test.go +++ b/cmdlang/inst_test.go @@ -5,7 +5,6 @@ import ( "context" "github.com/lmika/cmdlang-proto/cmdlang" "github.com/stretchr/testify/assert" - "strings" "testing" ) @@ -38,8 +37,6 @@ func TestInst_Eval(t *testing.T) { {desc: "multi 1", expr: `firstarg "hello" ; firstarg "world"`, want: "world"}, {desc: "multi 2", expr: `pipe "hello" | toUpper ; firstarg "world"`, want: "world"}, // TODO: assert for leaks {desc: "multi 3", expr: `set new "this is new" ; firstarg $new`, want: "this is new"}, - - {desc: "multi-line 1", expr: "echo \"Hello\" \n echo \"world\"", want: "world"}, } for _, tt := range tests { @@ -56,7 +53,7 @@ func TestInst_Eval(t *testing.T) { } } -func TestInst_Builtins_Echo(t *testing.T) { +func TestBuiltins_Echo(t *testing.T) { tests := []struct { desc string expr string @@ -65,8 +62,98 @@ func TestInst_Builtins_Echo(t *testing.T) { {desc: "no args", expr: `echo`, want: "\n"}, {desc: "single arg", expr: `echo "hello"`, want: "hello\n"}, {desc: "dual args", expr: `echo "hello " "world"`, want: "hello world\n"}, - {desc: "args to singleton stream", expr: `echo "aye" "bee" "see" | toUpper`, want: "AYEBEESEE\n"}, - {desc: "multi-line 1", expr: joinLines(`echo "Hello"`, `echo "world"`), want: "Hello\nworld"}, + {desc: "multi-line 1", expr: ` + echo "Hello" + echo "world" + `, want: "Hello\nworld\n"}, + {desc: "multi-line 2", expr: ` + echo "Hello" + + + echo "world" + `, want: "Hello\nworld\n"}, + {desc: "multi-line 3", expr: ` + +;;; + echo "Hello" +; + + echo "world" +; + `, want: "Hello\nworld\n"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + ctx := context.Background() + outW := bytes.NewBuffer(nil) + + inst := cmdlang.New(cmdlang.WithOut(outW), cmdlang.WithTestBuiltin()) + res, err := inst.Eval(ctx, tt.expr) + + assert.NoError(t, err) + assert.Nil(t, res) + assert.Equal(t, tt.want, outW.String()) + }) + } +} + +func TestBuiltins_If(t *testing.T) { + tests := []struct { + desc string + expr string + want string + }{ + {desc: "single then", expr: ` + set x "Hello" + if $x { + echo "true" + }`, want: "true\n(nil)\n"}, + {desc: "single then and else", expr: ` + set x "Hello" + if $x { + echo "true" + } else { + echo "false" + }`, want: "true\n(nil)\n"}, + {desc: "single then, elif and else", expr: ` + set x "Hello" + if $y { + echo "y is true" + } elif $x { + echo "x is true" + } else { + echo "nothings x" + }`, want: "x is true\n(nil)\n"}, + {desc: "single then and elif, no else", expr: ` + set x "Hello" + if $y { + echo "y is true" + } elif $x { + echo "x is true" + }`, want: "x is true\n(nil)\n"}, + {desc: "single then, two elif, and else", expr: ` + set x "Hello" + if $z { + echo "z is true" + } elif $y { + echo "y is true" + } elif $x { + echo "x is true" + }`, want: "x is true\n(nil)\n"}, + {desc: "single then, two elif, and else, expecting else", expr: ` + if $z { + echo "z is true" + } elif $y { + echo "y is true" + } elif $x { + echo "x is true" + } else { + echo "none is true" + }`, want: "none is true\n(nil)\n"}, + {desc: "compressed then", expr: `set x "Hello" ; if $x { echo "true" }`, want: "true\n(nil)\n"}, + {desc: "compressed else", expr: `if $x { echo "true" } else { echo "false" }`, want: "false\n(nil)\n"}, + {desc: "compressed if", expr: `if $x { echo "x" } elif $y { echo "y" } else { echo "false" }`, want: "false\n(nil)\n"}, } for _, tt := range tests { @@ -82,7 +169,3 @@ func TestInst_Builtins_Echo(t *testing.T) { }) } } - -func joinLines(ls ...string) string { - return strings.Join(ls, "\n") -} diff --git a/cmdlang/objs.go b/cmdlang/objs.go index 77ea948..45f26f6 100644 --- a/cmdlang/objs.go +++ b/cmdlang/objs.go @@ -9,6 +9,7 @@ import ( type object interface { String() string + Truthy() bool } type strObject string @@ -17,6 +18,10 @@ func (s strObject) String() string { return string(s) } +func (s strObject) Truthy() bool { + return string(s) != "" +} + func toGoValue(obj object) (interface{}, bool) { switch v := obj.(type) { case nil: @@ -28,6 +33,57 @@ func toGoValue(obj object) (interface{}, bool) { return nil, false } +type macroArgs struct { + eval evaluator + ec *evalCtx + currentStream stream + ast *astCmd + argShift int +} + +func (ma macroArgs) nargs() int { + return len(ma.ast.Args[ma.argShift:]) +} + +func (ma *macroArgs) shift(n int) { + ma.argShift += n +} + +func (ma macroArgs) identIs(ctx context.Context, n int, expectedIdent string) bool { + if n >= len(ma.ast.Args[ma.argShift:]) { + return false + } + + lit := ma.ast.Args[ma.argShift+n].Ident + if lit == nil { + return false + } + + return *lit == expectedIdent +} + +func (ma macroArgs) evalArg(ctx context.Context, n int) (object, error) { + if n >= len(ma.ast.Args[ma.argShift:]) { + return nil, errors.New("not enough arguments") // FIX + } + + return ma.eval.evalArg(ctx, ma.ec, ma.ast.Args[ma.argShift+n]) +} + +func (ma macroArgs) evalBlock(ctx context.Context, n int) (object, error) { + obj, err := ma.evalArg(ctx, n) + if err != nil { + return nil, err + } + + block, ok := obj.(blockObject) + if !ok { + return nil, errors.New("not a block object") + } + + return ma.eval.evalBlock(ctx, ma.ec, block.block) +} + type invocationArgs struct { inst *Inst ec *evalCtx @@ -58,6 +114,10 @@ type invokable interface { invoke(ctx context.Context, args invocationArgs) (object, error) } +type macroable interface { + invokeMacro(ctx context.Context, args macroArgs) (object, error) +} + type streamInvokable interface { invokable invokeWithStream(context.Context, stream, invocationArgs) (object, error) @@ -78,3 +138,28 @@ func (i invokableStreamFunc) invoke(ctx context.Context, args invocationArgs) (o func (i invokableStreamFunc) invokeWithStream(ctx context.Context, inStream stream, args invocationArgs) (object, error) { return i(ctx, inStream, args) } + +type blockObject struct { + block *astBlock +} + +func (bo blockObject) String() string { + return "block" +} + +func (bo blockObject) Truthy() bool { + return len(bo.block.Statements) > 0 +} + +type macroFunc func(ctx context.Context, args macroArgs) (object, error) + +func (i macroFunc) invokeMacro(ctx context.Context, args macroArgs) (object, error) { + return i(ctx, args) +} + +func isTruthy(obj object) bool { + if obj == nil { + return false + } + return obj.Truthy() +} diff --git a/cmdlang/streams.go b/cmdlang/streams.go index 5fd71f9..be460a3 100644 --- a/cmdlang/streams.go +++ b/cmdlang/streams.go @@ -74,6 +74,10 @@ func (s *singletonStream) String() string { return s.t.String() } +func (s *singletonStream) Truthy() bool { + return !s.consumed +} + func (s *singletonStream) next() (object, error) { if s.consumed { return nil, io.EOF @@ -93,6 +97,10 @@ func (s *listIterStream) String() string { return fmt.Sprintf("listIterStream{list: %v}", s.list) } +func (s *listIterStream) Truthy() bool { + return len(s.list) > s.cusr +} + func (s *listIterStream) next() (o object, err error) { if s.cusr >= len(s.list) { return nil, io.EOF @@ -115,6 +123,10 @@ func (ms mapFilterStream) String() string { return fmt.Sprintf("mapFilterStream{in: %v}", ms.in) } +func (ms mapFilterStream) Truthy() bool { + return true // ??? +} + func (ms mapFilterStream) next() (object, error) { for { u, err := ms.in.next()