浏览代码

refactor: dependency injection :)

ayn2op 1 年之前
父节点
当前提交
1e4d22fb77
共有 7 个文件被更改,包括 79 次插入69 次删除
  1. 12 9
      cmd/guilds_tree.go
  2. 16 15
      cmd/layout.go
  3. 2 1
      cmd/login_form.go
  4. 7 5
      cmd/message_input.go
  5. 34 31
      cmd/messages_text.go
  6. 3 6
      cmd/run.go
  7. 5 2
      cmd/state.go

+ 12 - 9
cmd/guilds_tree.go

@@ -6,6 +6,7 @@ import (
 	"sort"
 	"strings"
 
+	"github.com/ayn2op/discordo/internal/config"
 	"github.com/diamondburned/arikawa/v3/discord"
 	"github.com/diamondburned/arikawa/v3/gateway"
 	"github.com/gdamore/tcell/v2"
@@ -14,13 +15,15 @@ import (
 
 type GuildsTree struct {
 	*tview.TreeView
+	cfg               *config.Config
 	app               *tview.Application
 	selectedChannelID discord.ChannelID
 }
 
-func newGuildsTree(app *tview.Application) *GuildsTree {
+func newGuildsTree(app *tview.Application, cfg *config.Config) *GuildsTree {
 	gt := &GuildsTree{
 		TreeView: tview.NewTreeView(),
+		cfg:      cfg,
 		app:      app,
 	}
 
@@ -55,7 +58,7 @@ func (gt *GuildsTree) createFolderNode(folder gateway.GuildFolder) {
 
 	root := gt.GetRoot()
 	folderNode := tview.NewTreeNode(name)
-	folderNode.SetExpanded(cfg.Theme.GuildsTree.AutoExpandFolders)
+	folderNode.SetExpanded(gt.cfg.Theme.GuildsTree.AutoExpandFolders)
 	root.AddChild(folderNode)
 
 	for _, gID := range folder.GuildIDs {
@@ -72,7 +75,7 @@ func (gt *GuildsTree) createFolderNode(folder gateway.GuildFolder) {
 func (gt *GuildsTree) createGuildNode(n *tview.TreeNode, g discord.Guild) {
 	guildNode := tview.NewTreeNode(g.Name)
 	guildNode.SetReference(g.ID)
-	guildNode.SetColor(tcell.GetColor(cfg.Theme.GuildsTree.GuildColor))
+	guildNode.SetColor(tcell.GetColor(gt.cfg.Theme.GuildsTree.GuildColor))
 	n.AddChild(guildNode)
 }
 
@@ -125,7 +128,7 @@ func (gt *GuildsTree) createChannelNode(n *tview.TreeNode, c discord.Channel) *t
 
 	channelNode := tview.NewTreeNode(gt.channelToString(c))
 	channelNode.SetReference(c.ID)
-	channelNode.SetColor(tcell.GetColor(cfg.Theme.GuildsTree.ChannelColor))
+	channelNode.SetColor(tcell.GetColor(gt.cfg.Theme.GuildsTree.ChannelColor))
 	n.AddChild(channelNode)
 	return channelNode
 }
@@ -226,16 +229,16 @@ func (gt *GuildsTree) onSelected(n *tview.TreeNode) {
 
 func (gt *GuildsTree) onInputCapture(event *tcell.EventKey) *tcell.EventKey {
 	switch event.Name() {
-	case cfg.Keys.SelectPrevious:
+	case gt.cfg.Keys.SelectPrevious:
 		return tcell.NewEventKey(tcell.KeyUp, 0, tcell.ModNone)
-	case cfg.Keys.SelectNext:
+	case gt.cfg.Keys.SelectNext:
 		return tcell.NewEventKey(tcell.KeyDown, 0, tcell.ModNone)
-	case cfg.Keys.SelectFirst:
+	case gt.cfg.Keys.SelectFirst:
 		return tcell.NewEventKey(tcell.KeyHome, 0, tcell.ModNone)
-	case cfg.Keys.SelectLast:
+	case gt.cfg.Keys.SelectLast:
 		return tcell.NewEventKey(tcell.KeyEnd, 0, tcell.ModNone)
 
-	case cfg.Keys.GuildsTree.SelectCurrent:
+	case gt.cfg.Keys.GuildsTree.SelectCurrent:
 		return tcell.NewEventKey(tcell.KeyEnter, 0, tcell.ModNone)
 	}
 

+ 16 - 15
cmd/layout.go

@@ -3,6 +3,7 @@ package cmd
 import (
 	"log/slog"
 
+	"github.com/ayn2op/discordo/internal/config"
 	"github.com/ayn2op/discordo/internal/constants"
 	"github.com/gdamore/tcell/v2"
 	"github.com/rivo/tview"
@@ -10,23 +11,23 @@ import (
 )
 
 type Layout struct {
-	app  *tview.Application
-	flex *tview.Flex
-
+	cfg          *config.Config
+	app          *tview.Application
+	flex         *tview.Flex
 	guildsTree   *GuildsTree
 	messagesText *MessagesText
 	messageInput *MessageInput
 }
 
-func newLayout() *Layout {
+func newLayout(cfg *config.Config) *Layout {
 	app := tview.NewApplication()
 	l := &Layout{
 		app:  app,
 		flex: tview.NewFlex(),
 
-		guildsTree:   newGuildsTree(app),
-		messagesText: newMessagesText(app),
-		messageInput: newMessageInput(app),
+		guildsTree:   newGuildsTree(app, cfg),
+		messagesText: newMessagesText(app, cfg),
+		messageInput: newMessageInput(app, cfg),
 	}
 
 	l.init()
@@ -49,10 +50,10 @@ func (l *Layout) show(token string) error {
 			if err := l.show(token); err != nil {
 				slog.Error("failed to show app", "err", err)
 			}
-		})
+		}, l.cfg)
 		l.app.SetRoot(loginForm, true)
 	} else {
-		if err := openState(token, l.app); err != nil {
+		if err := openState(token, l.app, l.cfg); err != nil {
 			return err
 		}
 
@@ -84,7 +85,7 @@ func (l *Layout) init() {
 
 func (l *Layout) onAppInputCapture(event *tcell.EventKey) *tcell.EventKey {
 	switch event.Name() {
-	case cfg.Keys.Quit:
+	case l.cfg.Keys.Quit:
 		l.app.Stop()
 	case "Ctrl+C":
 		// https://github.com/rivo/tview/blob/a64fc48d7654432f71922c8b908280cdb525805c/application.go#L153
@@ -96,16 +97,16 @@ func (l *Layout) onAppInputCapture(event *tcell.EventKey) *tcell.EventKey {
 
 func (l *Layout) onFlexInputCapture(event *tcell.EventKey) *tcell.EventKey {
 	switch event.Name() {
-	case cfg.Keys.FocusGuildsTree:
+	case l.cfg.Keys.FocusGuildsTree:
 		l.app.SetFocus(l.guildsTree)
 		return nil
-	case cfg.Keys.FocusMessagesText:
+	case l.cfg.Keys.FocusMessagesText:
 		l.app.SetFocus(l.messagesText)
 		return nil
-	case cfg.Keys.FocusMessageInput:
+	case l.cfg.Keys.FocusMessageInput:
 		l.app.SetFocus(l.messageInput)
 		return nil
-	case cfg.Keys.Logout:
+	case l.cfg.Keys.Logout:
 		l.app.Stop()
 
 		if err := keyring.Delete(constants.Name, "token"); err != nil {
@@ -114,7 +115,7 @@ func (l *Layout) onFlexInputCapture(event *tcell.EventKey) *tcell.EventKey {
 		}
 
 		return nil
-	case cfg.Keys.ToggleGuildsTree:
+	case l.cfg.Keys.ToggleGuildsTree:
 		// The guilds tree is visible if the numbers of items is two.
 		if l.flex.GetItemCount() == 2 {
 			l.flex.RemoveItem(l.guildsTree)

+ 2 - 1
cmd/login_form.go

@@ -3,6 +3,7 @@ package cmd
 import (
 	"errors"
 
+	"github.com/ayn2op/discordo/internal/config"
 	"github.com/ayn2op/discordo/internal/constants"
 	"github.com/diamondburned/arikawa/v3/api"
 	"github.com/gdamore/tcell/v2"
@@ -17,7 +18,7 @@ type loginForm struct {
 	done doneFn
 }
 
-func newLoginForm(done doneFn) *loginForm {
+func newLoginForm(done doneFn, cfg *config.Config) *loginForm {
 	if done == nil {
 		done = func(_ string, _ error) {}
 	}

+ 7 - 5
cmd/message_input.go

@@ -7,6 +7,7 @@ import (
 	"strings"
 
 	"github.com/atotto/clipboard"
+	"github.com/ayn2op/discordo/internal/config"
 	"github.com/ayn2op/discordo/internal/constants"
 	"github.com/diamondburned/arikawa/v3/api"
 	"github.com/diamondburned/arikawa/v3/discord"
@@ -17,11 +18,12 @@ import (
 
 type MessageInput struct {
 	*tview.TextArea
+	cfg            *config.Config
 	app            *tview.Application
 	replyMessageID discord.MessageID
 }
 
-func newMessageInput(app *tview.Application) *MessageInput {
+func newMessageInput(app *tview.Application, cfg *config.Config) *MessageInput {
 	mi := &MessageInput{
 		TextArea: tview.NewTextArea(),
 		app:      app,
@@ -57,13 +59,13 @@ func (mi *MessageInput) reset() {
 
 func (mi *MessageInput) onInputCapture(event *tcell.EventKey) *tcell.EventKey {
 	switch event.Name() {
-	case cfg.Keys.MessageInput.Send:
+	case mi.cfg.Keys.MessageInput.Send:
 		mi.send()
 		return nil
-	case cfg.Keys.MessageInput.Editor:
+	case mi.cfg.Keys.MessageInput.Editor:
 		mi.editor()
 		return nil
-	case cfg.Keys.MessageInput.Cancel:
+	case mi.cfg.Keys.MessageInput.Cancel:
 		mi.reset()
 		return nil
 	}
@@ -113,7 +115,7 @@ func (mi *MessageInput) send() {
 }
 
 func (mi *MessageInput) editor() {
-	e := cfg.Editor
+	e := mi.cfg.Editor
 	if e == "default" {
 		e = os.Getenv("EDITOR")
 	}

+ 34 - 31
cmd/messages_text.go

@@ -10,6 +10,7 @@ import (
 	"time"
 
 	"github.com/atotto/clipboard"
+	"github.com/ayn2op/discordo/internal/config"
 	"github.com/ayn2op/discordo/internal/markdown"
 	"github.com/diamondburned/arikawa/v3/discord"
 	"github.com/diamondburned/ningen/v3/discordmd"
@@ -21,13 +22,15 @@ import (
 
 type MessagesText struct {
 	*tview.TextView
+	cfg               *config.Config
 	app               *tview.Application
 	selectedMessageID discord.MessageID
 }
 
-func newMessagesText(app *tview.Application) *MessagesText {
+func newMessagesText(app *tview.Application, cfg *config.Config) *MessagesText {
 	mt := &MessagesText{
 		TextView: tview.NewTextView(),
+		cfg:      cfg,
 		app:      app,
 	}
 
@@ -40,21 +43,21 @@ func newMessagesText(app *tview.Application) *MessagesText {
 		app.Draw()
 	})
 
-	mt.SetTextColor(tcell.GetColor(cfg.Theme.MessagesText.ContentColor))
-	mt.SetBackgroundColor(tcell.GetColor(cfg.Theme.BackgroundColor))
+	mt.SetTextColor(tcell.GetColor(mt.cfg.Theme.MessagesText.ContentColor))
+	mt.SetBackgroundColor(tcell.GetColor(mt.cfg.Theme.BackgroundColor))
 
 	mt.SetTitle("Messages")
-	mt.SetTitleColor(tcell.GetColor(cfg.Theme.TitleColor))
+	mt.SetTitleColor(tcell.GetColor(mt.cfg.Theme.TitleColor))
 	mt.SetTitleAlign(tview.AlignLeft)
 
-	p := cfg.Theme.BorderPadding
-	mt.SetBorder(cfg.Theme.Border)
-	mt.SetBorderColor(tcell.GetColor(cfg.Theme.BorderColor))
+	p := mt.cfg.Theme.BorderPadding
+	mt.SetBorder(mt.cfg.Theme.Border)
+	mt.SetBorderColor(tcell.GetColor(mt.cfg.Theme.BorderColor))
 	mt.SetBorderPadding(p[0], p[1], p[2], p[3])
 
 	markdown.DefaultRenderer.AddOptions(
-		renderer.WithOption("emojiColor", cfg.Theme.MessagesText.EmojiColor),
-		renderer.WithOption("linkColor", cfg.Theme.MessagesText.LinkColor),
+		renderer.WithOption("emojiColor", mt.cfg.Theme.MessagesText.EmojiColor),
+		renderer.WithOption("linkColor", mt.cfg.Theme.MessagesText.LinkColor),
 	)
 
 	mt.SetHighlightedFunc(mt.onHighlighted)
@@ -63,7 +66,7 @@ func newMessagesText(app *tview.Application) *MessagesText {
 }
 
 func (mt *MessagesText) drawMsgs(cID discord.ChannelID) {
-	ms, err := discordState.Messages(cID, uint(cfg.MessagesLimit))
+	ms, err := discordState.Messages(cID, uint(mt.cfg.MessagesLimit))
 	if err != nil {
 		slog.Error("failed to get messages", "err", err, "channel_id", cID)
 		return
@@ -97,7 +100,7 @@ func (mt *MessagesText) createMessage(m discord.Message) {
 	mt.startRegion(m.ID)
 	defer mt.endRegion()
 
-	if cfg.HideBlockedUsers {
+	if mt.cfg.HideBlockedUsers {
 		isBlocked := discordState.UserIsBlocked(m.Author.ID)
 		if isBlocked {
 			fmt.Fprintln(mt, "[:red:b]Blocked message[:-:-]")
@@ -107,7 +110,7 @@ func (mt *MessagesText) createMessage(m discord.Message) {
 
 	switch m.Type {
 	case discord.ChannelPinnedMessage:
-		fmt.Fprint(mt, "["+cfg.Theme.MessagesText.ContentColor+"]"+m.Author.Username+" pinned a message"+"[-:-:-]")
+		fmt.Fprint(mt, "["+mt.cfg.Theme.MessagesText.ContentColor+"]"+m.Author.Username+" pinned a message"+"[-:-:-]")
 	case discord.DefaultMessage, discord.InlinedReplyMessage:
 		if m.ReferencedMessage != nil {
 			mt.createHeader(mt, *m.ReferencedMessage, true)
@@ -127,16 +130,16 @@ func (mt *MessagesText) createMessage(m discord.Message) {
 }
 
 func (mt *MessagesText) createHeader(w io.Writer, m discord.Message, isReply bool) {
-	if cfg.Timestamps {
-		time := m.Timestamp.Time().In(time.Local).Format(cfg.TimestampsFormat)
+	if mt.cfg.Timestamps {
+		time := m.Timestamp.Time().In(time.Local).Format(mt.cfg.TimestampsFormat)
 		fmt.Fprintf(w, "[::d]%s[::-] ", time)
 	}
 
 	if isReply {
-		fmt.Fprintf(mt, "[::d]%s", cfg.Theme.MessagesText.ReplyIndicator)
+		fmt.Fprintf(mt, "[::d]%s", mt.cfg.Theme.MessagesText.ReplyIndicator)
 	}
 
-	fmt.Fprintf(w, "[%s]%s[-:-:-] ", cfg.Theme.MessagesText.AuthorColor, m.Author.Username)
+	fmt.Fprintf(w, "[%s]%s[-:-:-] ", mt.cfg.Theme.MessagesText.AuthorColor, m.Author.Username)
 }
 
 func (mt *MessagesText) createBody(w io.Writer, m discord.Message, isReply bool) {
@@ -156,10 +159,10 @@ func (mt *MessagesText) createBody(w io.Writer, m discord.Message, isReply bool)
 func (mt *MessagesText) createFooter(w io.Writer, m discord.Message) {
 	for _, a := range m.Attachments {
 		fmt.Fprintln(w)
-		if cfg.ShowAttachmentLinks {
-			fmt.Fprintf(w, "[%s][%s]:\n%s[-]", cfg.Theme.MessagesText.AttachmentColor, a.Filename, a.URL)
+		if mt.cfg.ShowAttachmentLinks {
+			fmt.Fprintf(w, "[%s][%s]:\n%s[-]", mt.cfg.Theme.MessagesText.AttachmentColor, a.Filename, a.URL)
 		} else {
-			fmt.Fprintf(w, "[%s][%s][-]", cfg.Theme.MessagesText.AttachmentColor, a.Filename)
+			fmt.Fprintf(w, "[%s][%s][-]", mt.cfg.Theme.MessagesText.AttachmentColor, a.Filename)
 		}
 	}
 }
@@ -194,22 +197,22 @@ func (mt *MessagesText) getSelectedMessageIndex() (int, error) {
 
 func (mt *MessagesText) onInputCapture(event *tcell.EventKey) *tcell.EventKey {
 	switch event.Name() {
-	case cfg.Keys.SelectPrevious, cfg.Keys.SelectNext, cfg.Keys.SelectFirst, cfg.Keys.SelectLast, cfg.Keys.MessagesText.SelectReply, cfg.Keys.MessagesText.SelectPin:
+	case mt.cfg.Keys.SelectPrevious, mt.cfg.Keys.SelectNext, mt.cfg.Keys.SelectFirst, mt.cfg.Keys.SelectLast, mt.cfg.Keys.MessagesText.SelectReply, mt.cfg.Keys.MessagesText.SelectPin:
 		mt._select(event.Name())
 		return nil
-	case cfg.Keys.MessagesText.Yank:
+	case mt.cfg.Keys.MessagesText.Yank:
 		mt.yank()
 		return nil
-	case cfg.Keys.MessagesText.Open:
+	case mt.cfg.Keys.MessagesText.Open:
 		mt.open()
 		return nil
-	case cfg.Keys.MessagesText.Reply:
+	case mt.cfg.Keys.MessagesText.Reply:
 		mt.reply(false)
 		return nil
-	case cfg.Keys.MessagesText.ReplyMention:
+	case mt.cfg.Keys.MessagesText.ReplyMention:
 		mt.reply(true)
 		return nil
-	case cfg.Keys.MessagesText.Delete:
+	case mt.cfg.Keys.MessagesText.Delete:
 		mt.delete()
 		return nil
 	}
@@ -231,7 +234,7 @@ func (mt *MessagesText) _select(name string) {
 	}
 
 	switch name {
-	case cfg.Keys.SelectPrevious:
+	case mt.cfg.Keys.SelectPrevious:
 		// If no message is currently selected, select the latest message.
 		if len(mt.GetHighlights()) == 0 {
 			mt.selectedMessageID = ms[0].ID
@@ -242,7 +245,7 @@ func (mt *MessagesText) _select(name string) {
 				return
 			}
 		}
-	case cfg.Keys.SelectNext:
+	case mt.cfg.Keys.SelectNext:
 		// If no message is currently selected, select the latest message.
 		if len(mt.GetHighlights()) == 0 {
 			mt.selectedMessageID = ms[0].ID
@@ -253,11 +256,11 @@ func (mt *MessagesText) _select(name string) {
 				return
 			}
 		}
-	case cfg.Keys.SelectFirst:
+	case mt.cfg.Keys.SelectFirst:
 		mt.selectedMessageID = ms[len(ms)-1].ID
-	case cfg.Keys.SelectLast:
+	case mt.cfg.Keys.SelectLast:
 		mt.selectedMessageID = ms[0].ID
-	case cfg.Keys.MessagesText.SelectReply:
+	case mt.cfg.Keys.MessagesText.SelectReply:
 		if mt.selectedMessageID == 0 {
 			return
 		}
@@ -269,7 +272,7 @@ func (mt *MessagesText) _select(name string) {
 				}
 			}
 		}
-	case cfg.Keys.MessagesText.SelectPin:
+	case mt.cfg.Keys.MessagesText.SelectPin:
 		if ref := ms[messageIdx].Reference; ref != nil {
 			for _, m := range ms {
 				if ref.MessageID == m.ID {

+ 3 - 6
cmd/run.go

@@ -7,9 +7,7 @@ import (
 
 var (
 	discordState *State
-
-	cfg      *config.Config
-	mainFlex *Layout
+	mainFlex     *Layout
 )
 
 func Run(token string) error {
@@ -17,12 +15,11 @@ func Run(token string) error {
 		return err
 	}
 
-	var err error
-	cfg, err = config.Load()
+	cfg, err := config.Load()
 	if err != nil {
 		return err
 	}
 
-	mainFlex = newLayout()
+	mainFlex = newLayout(cfg)
 	return mainFlex.run(token)
 }

+ 5 - 2
cmd/state.go

@@ -6,6 +6,7 @@ import (
 	"runtime"
 	"slices"
 
+	"github.com/ayn2op/discordo/internal/config"
 	"github.com/ayn2op/discordo/internal/constants"
 	"github.com/diamondburned/arikawa/v3/api"
 	"github.com/diamondburned/arikawa/v3/discord"
@@ -27,12 +28,14 @@ func init() {
 
 type State struct {
 	*ningen.State
+	cfg *config.Config
 	app *tview.Application
 }
 
-func openState(token string, app *tview.Application) error {
+func openState(token string, app *tview.Application, cfg *config.Config) error {
 	discordState = &State{
 		State: ningen.New(token),
+		cfg:   cfg,
 		app:   app,
 	}
 
@@ -57,7 +60,7 @@ func (s *State) onRequest(r httpdriver.Request) error {
 func (s *State) onReady(r *gateway.ReadyEvent) {
 	root := mainFlex.guildsTree.GetRoot()
 	dmNode := tview.NewTreeNode("Direct Messages")
-	dmNode.SetColor(tcell.GetColor(cfg.Theme.GuildsTree.PrivateChannelColor))
+	dmNode.SetColor(tcell.GetColor(s.cfg.Theme.GuildsTree.PrivateChannelColor))
 	root.AddChild(dmNode)
 
 	// Track guilds that have a parent (folder) to add orphan channels later