Made table information available to scripts (#49)

- Added a property with table information to session and result set Script types
- Added the ability to add new key bindings to the script
- Rebuilt the foreground job dispatcher to reduce the occurrence of the progress indicator showing up when no job was running.
- Fixed rebinding of keys. Rebinding a key will no longer clear other keys for the old or new bindings.
This commit is contained in:
Leon Mika 2023-02-22 21:53:05 +11:00 committed by GitHub
parent 733e59ec95
commit 3f1aec2c87
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 710 additions and 115 deletions

View file

@ -114,7 +114,7 @@ func main() {
scriptController := controllers.NewScriptController(scriptManagerService, tableReadController, settingsController, eventBus)
keyBindingService := keybindings_service.NewService(keyBindings)
keyBindingController := controllers.NewKeyBindingController(keyBindingService)
keyBindingController := controllers.NewKeyBindingController(keyBindingService, scriptController)
commandController := commandctrl.NewCommandController(inputHistoryService)
commandController.AddCommandLookupExtension(scriptController)

View file

@ -3,6 +3,7 @@ package controllers
import (
"context"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
tea "github.com/charmbracelet/bubbletea"
"github.com/lmika/audax/internal/dynamo-browse/models"
"io/fs"
)
@ -25,3 +26,10 @@ type SettingsProvider interface {
SetScriptLookupPaths(value string) error
ScriptLookupPaths() string
}
type CustomKeyBindingSource interface {
LookupBinding(theKey string) string
CustomKeyCommand(key string) tea.Cmd
UnbindKey(key string)
Rebind(bindingName string, newKey string) error
}

View file

@ -4,6 +4,7 @@ import (
"context"
tea "github.com/charmbracelet/bubbletea"
"github.com/lmika/audax/internal/common/ui/events"
"github.com/lmika/audax/internal/dynamo-browse/services/jobs"
)
func NewJob[T any](jc *JobsController, description string, job func(ctx context.Context) (T, error)) JobBuilder[T] {
@ -61,7 +62,7 @@ func (jb JobBuilder[T]) executeJob(ctx context.Context) tea.Msg {
}
func (jb JobBuilder[T]) doSubmit() tea.Msg {
jb.jc.service.SubmitForegroundJob(func(ctx context.Context) {
if err := jb.jc.service.SubmitForegroundJob(jobs.WithDescription(jb.description, jobs.JobFunc(func(ctx context.Context) {
msg := jb.executeJob(ctx)
jb.jc.msgSender(msg)
@ -73,12 +74,9 @@ func (jb JobBuilder[T]) doSubmit() tea.Msg {
JobStatus: "",
})
}
}, func(msg string) {
jb.jc.msgSender(events.ForegroundJobUpdate{
JobRunning: true,
JobStatus: jb.description + " " + msg,
})
})
}))); err != nil {
return events.Error(err)
}
return events.ForegroundJobUpdate{
JobRunning: true,

View file

@ -5,6 +5,7 @@ import (
"github.com/lmika/audax/internal/common/ui/events"
"github.com/lmika/audax/internal/dynamo-browse/services/jobs"
bus "github.com/lmika/events"
"log"
)
type JobsController struct {
@ -18,6 +19,9 @@ func NewJobsController(service *jobs.Services, bus *bus.Bus, immediate bool) *Jo
service: service,
immediate: immediate,
}
bus.On(jobs.JobStartEvent, func(job jobs.EventData) { jc.sendForegroundJobState(job.Job, "") })
bus.On(jobs.JobIdleEvent, func() { jc.sendForegroundJobState(nil, "") })
bus.On(jobs.JobUpdateEvent, func(job jobs.EventData, update string) { jc.sendForegroundJobState(job.Job, update) })
return jc
}
@ -36,3 +40,30 @@ func (js *JobsController) CancelRunningJob(ifNoJobsRunning func() tea.Msg) tea.M
}
return ifNoJobsRunning()
}
func (jc *JobsController) sendForegroundJobState(job jobs.Job, update string) {
if job == nil {
log.Printf("job service idle")
jc.msgSender(events.ForegroundJobUpdate{
JobRunning: false,
})
return
}
var statusMessage string
if dj, ok := job.(jobs.DescribableJob); ok {
statusMessage = dj.Description
} else {
statusMessage = "Working…"
}
if len(update) > 0 {
statusMessage += " " + update
}
log.Printf("job update: %v", statusMessage)
jc.msgSender(events.ForegroundJobUpdate{
JobRunning: true,
JobStatus: statusMessage,
})
}

View file

@ -10,31 +10,79 @@ import (
type KeyBindingController struct {
service *keybindings.Service
customBindingSource CustomKeyBindingSource
}
func NewKeyBindingController(service *keybindings.Service) *KeyBindingController {
return &KeyBindingController{service: service}
func NewKeyBindingController(service *keybindings.Service, customBindingSource CustomKeyBindingSource) *KeyBindingController {
return &KeyBindingController{
service: service,
customBindingSource: customBindingSource,
}
}
func (kb *KeyBindingController) Rebind(bindingName string, newKey string, force bool) tea.Msg {
err := kb.service.Rebind(bindingName, newKey, force)
if err == nil {
existingBinding := kb.findExistingBinding(newKey)
if existingBinding == "" {
if err := kb.rebind(bindingName, newKey); err != nil {
return events.Error(err)
}
return events.StatusMsg(fmt.Sprintf("Binding '%v' now bound to '%v'", bindingName, newKey))
} else if force {
return events.Error(errors.Wrapf(err, "cannot bind '%v' to '%v'", bindingName, newKey))
}
var keyAlreadyBoundErr keybindings.KeyAlreadyBoundError
if errors.As(err, &keyAlreadyBoundErr) {
promptMsg := fmt.Sprintf("Key '%v' already bound to '%v'. Continue? ", keyAlreadyBoundErr.Key, keyAlreadyBoundErr.ExistingBindingName)
//err := kb.rebind(bindingName, newKey, force)
//if err == nil {
// return events.StatusMsg(fmt.Sprintf("Binding '%v' now bound to '%v'", bindingName, newKey))
//} else if force {
// return events.Error(errors.Wrapf(err, "cannot bind '%v' to '%v'", bindingName, newKey))
//}
//
//var keyAlreadyBoundErr keybindings.KeyAlreadyBoundError
//if errors.As(err, &keyAlreadyBoundErr) {
promptMsg := fmt.Sprintf("Key '%v' already bound to '%v'. Continue? ", newKey, existingBinding)
return events.ConfirmYes(promptMsg, func() tea.Msg {
err := kb.service.Rebind(bindingName, newKey, true)
kb.unbindKey(newKey)
err := kb.rebind(bindingName, newKey)
if err != nil {
return events.Error(err)
}
return events.StatusMsg(fmt.Sprintf("Binding '%v' now bound to '%v'", bindingName, newKey))
})
//}
//return events.Error(err)
}
return events.Error(err)
func (kb *KeyBindingController) rebind(bindingName string, newKey string) error {
err := kb.service.Rebind(bindingName, newKey)
if err == nil {
return nil
}
var invalidBinding keybindings.InvalidBindingError
if !errors.As(err, &invalidBinding) {
return err
}
return kb.customBindingSource.Rebind(bindingName, newKey)
}
func (kb *KeyBindingController) unbindKey(key string) {
kb.service.UnbindKey(key)
kb.customBindingSource.UnbindKey(key)
}
func (kb *KeyBindingController) findExistingBinding(key string) string {
if binding := kb.service.LookupBinding(key); binding != "" {
return binding
}
return kb.customBindingSource.LookupBinding(key)
}
func (kb *KeyBindingController) LookupCustomBinding(key string) tea.Cmd {
if kb.customBindingSource == nil {
return nil
}
return kb.customBindingSource.CustomKeyCommand(key)
}

View file

@ -214,3 +214,33 @@ func (s *sessionImpl) Query(ctx context.Context, query string, opts scriptmanage
}
return newResultSet, nil
}
func (sc *ScriptController) CustomKeyCommand(key string) tea.Cmd {
_, cmd := sc.scriptManager.LookupKeyBinding(key)
if cmd == nil {
return nil
}
return func() tea.Msg {
errChan := sc.waitAndPrintScriptError()
ctx := context.Background()
if err := cmd.Invoke(ctx, nil, errChan); err != nil {
return events.Error(err)
}
return nil
}
}
func (sc *ScriptController) Rebind(bindingName string, newKey string) error {
return sc.scriptManager.RebindKeyBinding(bindingName, newKey)
}
func (sc *ScriptController) LookupBinding(theKey string) string {
bindingName, _ := sc.scriptManager.LookupKeyBinding(theKey)
return bindingName
}
func (sc *ScriptController) UnbindKey(key string) {
sc.scriptManager.UnbindKey(key)
}

View file

@ -1,9 +0,0 @@
package jobs
const (
JobEventForegroundDone = "job_foreground_done"
)
type JobDoneEvent struct {
Err error
}

View file

@ -3,65 +3,60 @@ package jobs
import (
"context"
bus "github.com/lmika/events"
"github.com/pkg/errors"
"sync"
)
type Job func(ctx context.Context)
type jobInfo struct {
ctx context.Context
job Job
cancelFn func()
}
type Services struct {
bus *bus.Bus
jobQueue chan Job
mutex *sync.Mutex
foregroundJob *jobInfo
}
func NewService(bus *bus.Bus) *Services {
return &Services{
jc := &Services{
bus: bus,
jobQueue: make(chan Job, 10),
mutex: new(sync.Mutex),
}
go jc.waitForJobs()
return jc
}
// SubmitForegroundJob starts a foreground job.
func (jc *Services) SubmitForegroundJob(job Job, onJobUpdate func(msg string)) {
// TODO: if there's already a foreground job, then return error
ctx, cancelFn := context.WithCancel(context.Background())
jobUpdateChan := make(chan string)
jobUpdater := &jobUpdaterValue{msgUpdate: jobUpdateChan}
ctx = context.WithValue(ctx, jobUpdaterKey, jobUpdater)
newJobInfo := &jobInfo{
ctx: ctx,
cancelFn: cancelFn,
func (jc *Services) SubmitForegroundJob(job Job) error {
select {
case jc.jobQueue <- job:
return nil
default:
return errors.New("too many jobs queued")
}
// TODO: needs to be protected by the mutex
}
func (jc *Services) setForegroundJob(newJobInfo *jobInfo) {
jc.mutex.Lock()
jc.foregroundJob = newJobInfo
jc.mutex.Unlock()
go func() {
defer cancelFn()
defer close(jobUpdateChan)
job(newJobInfo.ctx)
// TODO: needs to be protected by the mutex
jc.foregroundJob = nil
}()
go func() {
for update := range jobUpdateChan {
onJobUpdate(update)
if newJobInfo != nil {
jc.bus.Fire(JobStartEvent, EventData{Job: newJobInfo.job})
} else {
jc.bus.Fire(JobIdleEvent)
}
}()
}
func (jc *Services) CancelForegroundJob() bool {
jc.mutex.Lock()
defer jc.mutex.Unlock()
// TODO: needs to be protected by the mutex
if jc.foregroundJob != nil {
// A nil cancel for a non-nil foreground job indicates that the cancellation function
@ -77,3 +72,46 @@ func (jc *Services) CancelForegroundJob() bool {
return false
}
func (jc *Services) waitForJobs() {
ctx := context.Background()
for job := range jc.jobQueue {
jc.runJob(ctx, job)
if len(jc.jobQueue) == 0 {
jc.setForegroundJob(nil)
}
}
}
func (jc *Services) runJob(ctx context.Context, job Job) {
ctx, cancelFn := context.WithCancel(context.Background())
defer cancelFn()
updateCloseChan := make(chan struct{})
jobUpdateChan := make(chan string)
jobUpdater := &jobUpdaterValue{msgUpdate: jobUpdateChan}
ctx = context.WithValue(ctx, jobUpdaterKey, jobUpdater)
newJobInfo := &jobInfo{
job: job,
ctx: ctx,
cancelFn: cancelFn,
}
jc.setForegroundJob(newJobInfo)
go func() {
defer close(updateCloseChan)
for update := range jobUpdateChan {
jc.bus.Fire(JobUpdateEvent, EventData{Job: job}, update)
}
}()
job.Execute(newJobInfo.ctx)
close(jobUpdateChan)
<-updateCloseChan
}

View file

@ -0,0 +1,32 @@
package jobs
import "context"
const (
JobStartEvent = "jobs.start"
JobIdleEvent = "jobs.idle"
JobUpdateEvent = "jobs.update"
)
type EventData struct {
Job Job
}
type Job interface {
Execute(ctx context.Context)
}
type JobFunc func(ctx context.Context)
func (jf JobFunc) Execute(ctx context.Context) {
jf(ctx)
}
func WithDescription(description string, job Job) Job {
return DescribableJob{job, description}
}
type DescribableJob struct {
Job
Description string
}

View file

@ -12,3 +12,9 @@ type KeyAlreadyBoundError struct {
func (e KeyAlreadyBoundError) Error() string {
return fmt.Sprintf("key '%v' already bound to '%v'", e.Key, e.ExistingBindingName)
}
type InvalidBindingError string
func (e InvalidBindingError) Error() string {
return fmt.Sprintf("invalid binding: %v", string(e))
}

View file

@ -2,8 +2,6 @@ package keybindings
import (
"github.com/charmbracelet/bubbles/key"
"github.com/pkg/errors"
"log"
"reflect"
"strings"
)
@ -23,37 +21,56 @@ func NewService(keyBinding any) *Service {
}
}
func (s *Service) Rebind(name string, newKey string, force bool) error {
// Check if there already exists a binding (or clear it)
func (s *Service) LookupBinding(theKey string) string {
var foundBinding = ""
s.walkBindingFields(func(bindingName string, binding *key.Binding) bool {
for _, boundKey := range binding.Keys() {
if boundKey == newKey {
if force {
// TODO: only filter out "boundKey" rather clear
log.Printf("clearing binding of %v", bindingName)
*binding = key.NewBinding()
return true
} else {
if boundKey == theKey {
foundBinding = bindingName
return false
}
}
return true
})
return foundBinding
}
func (s *Service) UnbindKey(theKey string) {
s.walkBindingFields(func(bindingName string, binding *key.Binding) bool {
for _, boundKey := range binding.Keys() {
if boundKey == theKey {
l := len(binding.Keys())
if l == 1 {
*binding = key.NewBinding()
} else if l > 1 {
newKeys := make([]string, 0)
for _, k := range binding.Keys() {
if k != theKey {
newKeys = append(newKeys, k)
}
}
*binding = key.NewBinding(key.WithKeys(newKeys...))
}
}
}
return true
})
if foundBinding != "" {
return KeyAlreadyBoundError{Key: newKey, ExistingBindingName: foundBinding}
}
func (s *Service) Rebind(name string, newKey string) error {
// Rebind
binding := s.findFieldForBinding(name)
if binding == nil {
return errors.Errorf("invalid binding: %v", name)
return InvalidBindingError(name)
}
if len(binding.Keys()) == 0 {
*binding = key.NewBinding(key.WithKeys(newKey))
} else {
newKeys := append([]string{newKey}, binding.Keys()...)
*binding = key.NewBinding(key.WithKeys(newKeys...))
}
return nil
}

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.16.0. DO NOT EDIT.
// Code generated by mockery v2.20.0. DO NOT EDIT.
package mocks
@ -29,6 +29,10 @@ func (_m *SessionService) Query(ctx context.Context, expr string, queryOptions s
ret := _m.Called(ctx, expr, queryOptions)
var r0 *models.ResultSet
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, scriptmanager.QueryOptions) (*models.ResultSet, error)); ok {
return rf(ctx, expr, queryOptions)
}
if rf, ok := ret.Get(0).(func(context.Context, string, scriptmanager.QueryOptions) *models.ResultSet); ok {
r0 = rf(ctx, expr, queryOptions)
} else {
@ -37,7 +41,6 @@ func (_m *SessionService) Query(ctx context.Context, expr string, queryOptions s
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, string, scriptmanager.QueryOptions) error); ok {
r1 = rf(ctx, expr, queryOptions)
} else {
@ -72,6 +75,11 @@ func (_c *SessionService_Query_Call) Return(_a0 *models.ResultSet, _a1 error) *S
return _c
}
func (_c *SessionService_Query_Call) RunAndReturn(run func(context.Context, string, scriptmanager.QueryOptions) (*models.ResultSet, error)) *SessionService_Query_Call {
_c.Call.Return(run)
return _c
}
// ResultSet provides a mock function with given fields: ctx
func (_m *SessionService) ResultSet(ctx context.Context) *models.ResultSet {
ret := _m.Called(ctx)
@ -111,6 +119,11 @@ func (_c *SessionService_ResultSet_Call) Return(_a0 *models.ResultSet) *SessionS
return _c
}
func (_c *SessionService_ResultSet_Call) RunAndReturn(run func(context.Context) *models.ResultSet) *SessionService_ResultSet_Call {
_c.Call.Return(run)
return _c
}
// SelectedItemIndex provides a mock function with given fields: ctx
func (_m *SessionService) SelectedItemIndex(ctx context.Context) int {
ret := _m.Called(ctx)
@ -148,6 +161,11 @@ func (_c *SessionService_SelectedItemIndex_Call) Return(_a0 int) *SessionService
return _c
}
func (_c *SessionService_SelectedItemIndex_Call) RunAndReturn(run func(context.Context) int) *SessionService_SelectedItemIndex_Call {
_c.Call.Return(run)
return _c
}
// SetResultSet provides a mock function with given fields: ctx, newResultSet
func (_m *SessionService) SetResultSet(ctx context.Context, newResultSet *models.ResultSet) {
_m.Called(ctx, newResultSet)
@ -177,6 +195,11 @@ func (_c *SessionService_SetResultSet_Call) Return() *SessionService_SetResultSe
return _c
}
func (_c *SessionService_SetResultSet_Call) RunAndReturn(run func(context.Context, *models.ResultSet)) *SessionService_SetResultSet_Call {
_c.Call.Return(run)
return _c
}
type mockConstructorTestingTNewSessionService interface {
mock.TestingT
Cleanup(func())

View file

@ -1,4 +1,4 @@
// Code generated by mockery v2.16.0. DO NOT EDIT.
// Code generated by mockery v2.20.0. DO NOT EDIT.
package mocks
@ -50,6 +50,11 @@ func (_c *UIService_PrintMessage_Call) Return() *UIService_PrintMessage_Call {
return _c
}
func (_c *UIService_PrintMessage_Call) RunAndReturn(run func(context.Context, string)) *UIService_PrintMessage_Call {
_c.Call.Return(run)
return _c
}
// Prompt provides a mock function with given fields: ctx, msg
func (_m *UIService) Prompt(ctx context.Context, msg string) chan string {
ret := _m.Called(ctx, msg)
@ -90,6 +95,11 @@ func (_c *UIService_Prompt_Call) Return(_a0 chan string) *UIService_Prompt_Call
return _c
}
func (_c *UIService_Prompt_Call) RunAndReturn(run func(context.Context, string) chan string) *UIService_Prompt_Call {
_c.Call.Return(run)
return _c
}
type mockConstructorTestingTNewUIService interface {
mock.TestingT
Cleanup(func())

View file

@ -2,10 +2,16 @@ package scriptmanager
import (
"context"
"fmt"
"github.com/cloudcmds/tamarin/arg"
"github.com/cloudcmds/tamarin/object"
"github.com/cloudcmds/tamarin/scope"
"github.com/pkg/errors"
"regexp"
)
var (
validKeyBindingNames = regexp.MustCompile(`^[-a-zA-Z0-9_]+$`)
)
type extModule struct {
@ -18,6 +24,7 @@ func (m *extModule) register(scp *scope.Scope) {
modScope.AddBuiltins([]*object.Builtin{
object.NewBuiltin("command", m.command, mod),
object.NewBuiltin("key_binding", m.keyBinding, mod),
})
scp.Declare("ext", mod, true)
@ -65,3 +72,64 @@ func (m *extModule) command(ctx context.Context, args ...object.Object) object.O
m.scriptPlugin.definedCommands[cmdName] = &Command{plugin: m.scriptPlugin, cmdFn: newCommand}
return nil
}
func (m *extModule) keyBinding(ctx context.Context, args ...object.Object) object.Object {
if err := arg.Require("ext.key_binding", 3, args); err != nil {
return err
}
bindingName, err := object.AsString(args[0])
if err != nil {
return err
} else if !validKeyBindingNames.MatchString(bindingName) {
return object.NewError(errors.New("value error: binding name must match regexp [-a-zA-Z0-9_]+"))
}
options, err := object.AsMap(args[1])
if err != nil {
return err
}
var defaultKey string
if strVal, isStrVal := options.Get("default").(*object.String); isStrVal {
defaultKey = strVal.Value()
}
fnRes, isFnRes := args[2].(*object.Function)
if !isFnRes {
return object.NewError(errors.New("expected second arg to be a function"))
}
callFn, hasCallFn := object.GetCallFunc(ctx)
if !hasCallFn {
return object.NewError(errors.New("no callFn found in context"))
}
// This command function will be executed by the script scheduler
newCommand := func(ctx context.Context, args []string) error {
objArgs := make([]object.Object, len(args))
for i, a := range args {
objArgs[i] = object.NewString(a)
}
ctx = ctxWithOptions(ctx, m.scriptPlugin.scriptService.options)
res := callFn(ctx, fnRes.Scope(), fnRes, objArgs)
if object.IsError(res) {
errObj := res.(*object.Error)
return errors.Errorf("command error '%v':%v - %v", m.scriptPlugin.name, bindingName, errObj.Inspect())
}
return nil
}
fullBindingName := fmt.Sprintf("ext.%v.%v", m.scriptPlugin.name, bindingName)
if m.scriptPlugin.definedKeyBindings == nil {
m.scriptPlugin.definedKeyBindings = make(map[string]*Command)
m.scriptPlugin.keyToKeyBinding = make(map[string]string)
}
m.scriptPlugin.definedKeyBindings[fullBindingName] = &Command{plugin: m.scriptPlugin, cmdFn: newCommand}
m.scriptPlugin.keyToKeyBinding[defaultKey] = fullBindingName
return nil
}

View file

@ -33,8 +33,15 @@ func (um *sessionModule) query(ctx context.Context, args ...object.Object) objec
}
// Table name
if val, isVal := objMap.Get("table").(*object.String); isVal && val.Value() != "" {
options.TableName = val.Value()
if val := objMap.Get("table"); val != object.Nil && val.IsTruthy() {
switch tv := val.(type) {
case *object.String:
options.TableName = tv.Value()
case *tableProxy:
options.TableName = tv.table.Name
default:
return object.Errorf("type error: query option 'table' must be either a string or table")
}
}
// Placeholders
@ -111,12 +118,26 @@ func (um *sessionModule) setResultSet(ctx context.Context, args ...object.Object
return nil
}
func (um *sessionModule) currentTable(ctx context.Context, args ...object.Object) object.Object {
if err := arg.Require("session.current_table", 0, args); err != nil {
return err
}
rs := um.sessionService.ResultSet(ctx)
if rs == nil {
return object.Nil
}
return &tableProxy{table: rs.TableInfo}
}
func (um *sessionModule) register(scp *scope.Scope) {
modScope := scope.New(scope.Opts{})
mod := object.NewModule("session", modScope)
modScope.AddBuiltins([]*object.Builtin{
object.NewBuiltin("query", um.query, mod),
object.NewBuiltin("current_table", um.currentTable, mod),
object.NewBuiltin("result_set", um.resultSet, mod),
object.NewBuiltin("selected_item", um.selectedItem, mod),
object.NewBuiltin("set_result_set", um.setResultSet, mod),

View file

@ -12,6 +12,78 @@ import (
"testing"
)
func TestModSession_Table(t *testing.T) {
t.Run("should return details of the current table", func(t *testing.T) {
tableDef := models.TableInfo{
Name: "test_table",
Keys: models.KeyAttribute{
PartitionKey: "pk",
SortKey: "sk",
},
GSIs: []models.TableGSI{
{
Name: "index-1",
Keys: models.KeyAttribute{
PartitionKey: "ipk",
SortKey: "isk",
},
},
},
}
rs := models.ResultSet{TableInfo: &tableDef}
mockedSessionService := mocks.NewSessionService(t)
mockedSessionService.EXPECT().ResultSet(mock.Anything).Return(&rs)
testFS := testScriptFile(t, "test.tm", `
table := session.current_table()
assert(table.name == "test_table")
assert(table.keys["partition"] == "pk")
assert(table.keys["sort"] == "sk")
assert(len(table.gsis) == 1)
assert(table.gsis[0].name == "index-1")
assert(table.gsis[0].keys["partition"] == "ipk")
assert(table.gsis[0].keys["sort"] == "isk")
assert(table == session.result_set().table)
`)
srv := scriptmanager.New(scriptmanager.WithFS(testFS))
srv.SetIFaces(scriptmanager.Ifaces{
Session: mockedSessionService,
})
ctx := context.Background()
err := <-srv.RunAdHocScript(ctx, "test.tm")
assert.NoError(t, err)
mockedSessionService.AssertExpectations(t)
})
t.Run("should return nil if no current result set", func(t *testing.T) {
mockedSessionService := mocks.NewSessionService(t)
mockedSessionService.EXPECT().ResultSet(mock.Anything).Return(nil)
testFS := testScriptFile(t, "test.tm", `
table := session.current_table()
assert(table == nil)
`)
srv := scriptmanager.New(scriptmanager.WithFS(testFS))
srv.SetIFaces(scriptmanager.Ifaces{
Session: mockedSessionService,
})
ctx := context.Background()
err := <-srv.RunAdHocScript(ctx, "test.tm")
assert.NoError(t, err)
mockedSessionService.AssertExpectations(t)
})
}
func TestModSession_Query(t *testing.T) {
t.Run("should successfully return query result", func(t *testing.T) {
rs := &models.ResultSet{}
@ -110,6 +182,42 @@ func TestModSession_Query(t *testing.T) {
mockedSessionService.AssertExpectations(t)
})
t.Run("should successfully specify table proxy", func(t *testing.T) {
rs := &models.ResultSet{}
mockedSessionService := mocks.NewSessionService(t)
mockedSessionService.EXPECT().ResultSet(mock.Anything).Return(&models.ResultSet{
TableInfo: &models.TableInfo{
Name: "some-resultset-table",
},
})
mockedSessionService.EXPECT().Query(mock.Anything, "some expr", scriptmanager.QueryOptions{
TableName: "some-resultset-table",
}).Return(rs, nil)
mockedUIService := mocks.NewUIService(t)
testFS := testScriptFile(t, "test.tm", `
res := session.query("some expr", {
table: session.result_set().table,
})
assert(!res.is_err())
`)
srv := scriptmanager.New(scriptmanager.WithFS(testFS))
srv.SetIFaces(scriptmanager.Ifaces{
UI: mockedUIService,
Session: mockedSessionService,
})
ctx := context.Background()
err := <-srv.RunAdHocScript(ctx, "test.tm")
assert.NoError(t, err)
mockedUIService.AssertExpectations(t)
mockedSessionService.AssertExpectations(t)
})
t.Run("should set placeholder values", func(t *testing.T) {
rs := &models.ResultSet{}

View file

@ -91,6 +91,8 @@ func (r *resultSetProxy) Iter() object.Iterator {
func (r *resultSetProxy) GetAttr(name string) (object.Object, bool) {
switch name {
case "table":
return &tableProxy{table: r.resultSet.TableInfo}, true
case "length":
return object.NewInt(int64(len(r.resultSet.Items()))), true
}

View file

@ -13,7 +13,11 @@ import (
func TestResultSetProxy(t *testing.T) {
t.Run("should property return properties of a resultset and item", func(t *testing.T) {
rs := &models.ResultSet{}
rs := &models.ResultSet{
TableInfo: &models.TableInfo{
Name: "test-table",
},
}
rs.SetItems([]models.Item{
{"pk": &types.AttributeValueMemberS{Value: "abc"}},
{"pk": &types.AttributeValueMemberS{Value: "1232"}},
@ -28,6 +32,8 @@ func TestResultSetProxy(t *testing.T) {
res := session.query("some expr").unwrap()
// Test properties of the result set
assert(res.table.name, "hello")
assert(res == res, "result_set.equals")
assert(res.length == 2, "result_set.length")

View file

@ -4,10 +4,12 @@ import (
"context"
"github.com/cloudcmds/tamarin/exec"
"github.com/cloudcmds/tamarin/scope"
"github.com/lmika/audax/internal/dynamo-browse/services/keybindings"
"github.com/pkg/errors"
"io/fs"
"os"
"path/filepath"
"strings"
)
type Service struct {
@ -119,7 +121,7 @@ func (s *Service) loadScript(ctx context.Context, filename string, resChan chan
}
newPlugin := &ScriptPlugin{
name: filepath.Base(filename),
name: strings.TrimSuffix(filepath.Base(filename), filepath.Ext(filename)),
scriptService: s,
}
@ -176,6 +178,49 @@ func (s *Service) LookupCommand(name string) *Command {
return nil
}
func (s *Service) LookupKeyBinding(key string) (string, *Command) {
for _, p := range s.plugins {
if bindingName, hasBinding := p.keyToKeyBinding[key]; hasBinding {
if cmd, hasCmd := p.definedKeyBindings[bindingName]; hasCmd {
return bindingName, cmd
}
}
}
return "", nil
}
func (s *Service) UnbindKey(key string) {
for _, p := range s.plugins {
if _, hasBinding := p.keyToKeyBinding[key]; hasBinding {
delete(p.keyToKeyBinding, key)
}
}
}
func (s *Service) RebindKeyBinding(keyBinding string, newKey string) error {
if newKey == "" {
for _, p := range s.plugins {
for k, b := range p.keyToKeyBinding {
if b == keyBinding {
delete(p.keyToKeyBinding, k)
}
}
}
return nil
}
for _, p := range s.plugins {
if _, hasCmd := p.definedKeyBindings[keyBinding]; hasCmd {
if newKey != "" {
p.keyToKeyBinding[newKey] = keyBinding
}
return nil
}
}
return keybindings.InvalidBindingError(keyBinding)
}
func (s *Service) parentScope() *scope.Scope {
scp := scope.New(scope.Opts{})
(&uiModule{uiService: s.ifaces.UI}).register(scp)

View file

@ -55,7 +55,7 @@ func TestService_LoadScript(t *testing.T) {
plugin, err := srv.LoadScript(ctx, "test.tm")
assert.NoError(t, err)
assert.NotNil(t, plugin)
assert.Equal(t, "test.tm", plugin.Name())
assert.Equal(t, "test", plugin.Name())
cmd := srv.LookupCommand("somewhere")
assert.NotNil(t, cmd)

View file

@ -0,0 +1,105 @@
package scriptmanager
import (
"github.com/cloudcmds/tamarin/object"
"github.com/lmika/audax/internal/common/sliceutils"
"github.com/lmika/audax/internal/dynamo-browse/models"
"reflect"
)
const (
tableProxyPartitionKey = "partition"
tableProxySortKey = "sort"
)
type tableProxy struct {
table *models.TableInfo
}
func (t *tableProxy) Type() object.Type {
return "table"
}
func (t *tableProxy) Inspect() string {
return "table(" + t.table.Name + ")"
}
func (t *tableProxy) Interface() interface{} {
return t.table
}
func (t *tableProxy) Equals(other object.Object) object.Object {
otherT, isOtherRS := other.(*tableProxy)
if !isOtherRS {
return object.False
}
return object.NewBool(reflect.DeepEqual(t.table, otherT.table))
}
func (t *tableProxy) GetAttr(name string) (object.Object, bool) {
switch name {
case "name":
return object.NewString(t.table.Name), true
case "keys":
return object.NewMap(map[string]object.Object{
tableProxyPartitionKey: object.NewString(t.table.Keys.PartitionKey),
tableProxySortKey: object.NewString(t.table.Keys.SortKey),
}), true
case "gsis":
return object.NewList(sliceutils.Map(t.table.GSIs, newTableIndexProxy)), true
}
return nil, false
}
func (t *tableProxy) IsTruthy() bool {
return true
}
type tableIndexProxy struct {
gsi models.TableGSI
}
func newTableIndexProxy(gsi models.TableGSI) object.Object {
return tableIndexProxy{gsi: gsi}
}
func (t tableIndexProxy) Type() object.Type {
return "index"
}
func (t tableIndexProxy) Inspect() string {
return "index(gsi," + t.gsi.Name + ")"
}
func (t tableIndexProxy) Interface() interface{} {
return t.gsi
}
func (t tableIndexProxy) Equals(other object.Object) object.Object {
otherIP, isOtherIP := other.(tableIndexProxy)
if !isOtherIP {
return object.False
}
return object.NewBool(reflect.DeepEqual(t.gsi, otherIP.gsi))
}
func (t tableIndexProxy) GetAttr(name string) (object.Object, bool) {
switch name {
case "name":
return object.NewString(t.gsi.Name), true
case "keys":
return object.NewMap(map[string]object.Object{
tableProxyPartitionKey: object.NewString(t.gsi.Keys.PartitionKey),
tableProxySortKey: object.NewString(t.gsi.Keys.SortKey),
}), true
}
return nil, false
}
func (t tableIndexProxy) IsTruthy() bool {
return true
}

View file

@ -6,6 +6,8 @@ type ScriptPlugin struct {
scriptService *Service
name string
definedCommands map[string]*Command
definedKeyBindings map[string]*Command
keyToKeyBinding map[string]string
}
func (sp *ScriptPlugin) Name() string {

View file

@ -59,6 +59,7 @@ type Model struct {
itemView *dynamoitemview.Model
mainView tea.Model
keyMap *keybindings.ViewKeyBindings
keyBindingController *controllers.KeyBindingController
}
func NewModel(
@ -231,6 +232,7 @@ func NewModel(
itemView: div,
mainView: mainView,
keyMap: defaultKeyMap.View,
keyBindingController: keyBindingController,
}
}
@ -285,6 +287,10 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, events.SetTeaMessage(m.jobController.CancelRunningJob(m.promptToQuit))
case key.Matches(msg, m.keyMap.Quit):
return m, m.promptToQuit
default:
if cmd := m.keyBindingController.LookupCustomBinding(msg.String()); cmd != nil {
return m, cmd
}
}
}
}