diff --git a/db.go b/db.go index 9bb08ab..c093064 100644 --- a/db.go +++ b/db.go @@ -3,6 +3,7 @@ package soju import ( "database/sql" "fmt" + "strings" "sync" _ "github.com/mattn/go-sqlite3" @@ -23,14 +24,15 @@ type SASL struct { } type Network struct { - ID int64 - Name string - Addr string - Nick string - Username string - Realname string - Pass string - SASL SASL + ID int64 + Name string + Addr string + Nick string + Username string + Realname string + Pass string + ConnectCommands []string + SASL SASL } func (net *Network) GetName() string { @@ -63,6 +65,7 @@ CREATE TABLE Network ( username VARCHAR(255), realname VARCHAR(255), pass VARCHAR(255), + connect_commands VARCHAR(1023), sasl_mechanism VARCHAR(255), sasl_plain_username VARCHAR(255), sasl_plain_password VARCHAR(255), @@ -82,6 +85,7 @@ CREATE TABLE Channel ( var migrations = []string{ "", // migration #0 is reserved for schema initialization + "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)", } type DB struct { @@ -233,7 +237,7 @@ func (db *DB) ListNetworks(username string) ([]Network, error) { defer db.lock.RUnlock() rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass, - sasl_mechanism, sasl_plain_username, sasl_plain_password + connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password FROM Network WHERE user = ?`, username) @@ -245,10 +249,10 @@ func (db *DB) ListNetworks(username string) ([]Network, error) { var networks []Network for rows.Next() { var net Network - var name, username, realname, pass *string + var name, username, realname, pass, connectCommands *string var saslMechanism, saslPlainUsername, saslPlainPassword *string err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname, - &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword) + &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword) if err != nil { return nil, err } @@ -256,6 +260,9 @@ func (db *DB) ListNetworks(username string) ([]Network, error) { net.Username = fromStringPtr(username) net.Realname = fromStringPtr(realname) net.Pass = fromStringPtr(pass) + if connectCommands != nil { + net.ConnectCommands = strings.Split(*connectCommands, "\r\n") + } net.SASL.Mechanism = fromStringPtr(saslMechanism) net.SASL.Plain.Username = fromStringPtr(saslPlainUsername) net.SASL.Plain.Password = fromStringPtr(saslPlainPassword) @@ -276,6 +283,7 @@ func (db *DB) StoreNetwork(username string, network *Network) error { netUsername := toStringPtr(network.Username) realname := toStringPtr(network.Realname) pass := toStringPtr(network.Pass) + connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n")) var saslMechanism, saslPlainUsername, saslPlainPassword *string if network.SASL.Mechanism != "" { @@ -292,18 +300,18 @@ func (db *DB) StoreNetwork(username string, network *Network) error { var err error if network.ID != 0 { _, err = db.db.Exec(`UPDATE Network - SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, + SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?, sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ? WHERE id = ?`, - netName, network.Addr, network.Nick, netUsername, realname, pass, + netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands, saslMechanism, saslPlainUsername, saslPlainPassword, network.ID) } else { var res sql.Result res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username, - realname, pass, sasl_mechanism, sasl_plain_username, + realname, pass, connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - username, netName, network.Addr, network.Nick, netUsername, realname, pass, + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands, saslMechanism, saslPlainUsername, saslPlainPassword) if err != nil { return err diff --git a/service.go b/service.go index 4e73873..ecbd6aa 100644 --- a/service.go +++ b/service.go @@ -104,7 +104,7 @@ func init() { "network": { children: serviceCommandSet{ "create": { - usage: "-addr [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick]", + usage: "-addr [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [[-connect-command command] ...]", desc: "add a new network", handle: handleServiceCreateNetwork, }, @@ -174,6 +174,17 @@ func newFlagSet() *flag.FlagSet { return fs } +type stringSliceVar []string + +func (v *stringSliceVar) String() string { + return fmt.Sprint([]string(*v)) +} + +func (v *stringSliceVar) Set(s string) error { + *v = append(*v, s) + return nil +} + func handleServiceCreateNetwork(dc *downstreamConn, params []string) error { fs := newFlagSet() addr := fs.String("addr", "", "") @@ -182,6 +193,8 @@ func handleServiceCreateNetwork(dc *downstreamConn, params []string) error { pass := fs.String("pass", "", "") realname := fs.String("realname", "", "") nick := fs.String("nick", "", "") + var connectCommands stringSliceVar + fs.Var(&connectCommands, "connect-command", "") if err := fs.Parse(params); err != nil { return err @@ -190,18 +203,26 @@ func handleServiceCreateNetwork(dc *downstreamConn, params []string) error { return fmt.Errorf("flag -addr is required") } + for _, command := range connectCommands { + _, err := irc.ParseMessage(command) + if err != nil { + return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err) + } + } + if *nick == "" { *nick = dc.nick } var err error network, err := dc.user.createNetwork(&Network{ - Addr: *addr, - Name: *name, - Username: *username, - Pass: *pass, - Realname: *realname, - Nick: *nick, + Addr: *addr, + Name: *name, + Username: *username, + Pass: *pass, + Realname: *realname, + Nick: *nick, + ConnectCommands: connectCommands, }) if err != nil { return fmt.Errorf("could not create network: %v", err) diff --git a/upstream.go b/upstream.go index e563955..599a9d7 100644 --- a/upstream.go +++ b/upstream.go @@ -1189,6 +1189,15 @@ func (uc *upstreamConn) runUntilRegistered() error { } } + for _, command := range uc.network.ConnectCommands { + m, err := irc.ParseMessage(command) + if err != nil { + uc.logger.Printf("failed to parse connect command %q: %v", command, err) + } else { + uc.SendMessage(m) + } + } + return nil }