diff --git a/database/database.go b/database/database.go index b4a6ac8..061d529 100644 --- a/database/database.go +++ b/database/database.go @@ -130,6 +130,7 @@ type Network struct { Realname string Pass string ConnectCommands []string + CertFP string SASL SASL AutoAway bool Enabled bool diff --git a/database/postgres.go b/database/postgres.go index 9af5c4a..ce940bb 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -44,6 +44,7 @@ CREATE TABLE "Network" ( nick VARCHAR(255), username VARCHAR(255), realname VARCHAR(255), + certfp TEXT, pass VARCHAR(255), connect_commands VARCHAR(1023), sasl_mechanism sasl_mechanism, @@ -165,6 +166,7 @@ var postgresMigrations = []string{ SET NOT NULL; `, `ALTER TABLE "Network" ADD COLUMN auto_away BOOLEAN NOT NULL DEFAULT TRUE`, + `ALTER TABLE "Network" ADD COLUMN certfp TEXT`, } type PostgresDB struct { @@ -380,7 +382,7 @@ func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network defer cancel() rows, err := db.db.QueryContext(ctx, ` - SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism, + SELECT id, name, addr, nick, username, realname, certfp, pass, connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, auto_away, enabled FROM "Network" WHERE "user" = $1`, userID) @@ -392,9 +394,9 @@ func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network var networks []Network for rows.Next() { var net Network - var name, nick, username, realname, pass, connectCommands sql.NullString + var name, nick, username, realname, certfp, pass, connectCommands sql.NullString var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString - err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, + err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, &certfp, &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.AutoAway, &net.Enabled) if err != nil { @@ -404,6 +406,7 @@ func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network net.Nick = nick.String net.Username = username.String net.Realname = realname.String + net.CertFP = certfp.String net.Pass = pass.String if connectCommands.Valid { net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") @@ -428,6 +431,7 @@ func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *N nick := toNullString(network.Nick) netUsername := toNullString(network.Username) realname := toNullString(network.Realname) + certfp := toNullString(network.CertFP) pass := toNullString(network.Pass) connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n")) @@ -450,23 +454,23 @@ func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *N var err error if network.ID == 0 { err = db.db.QueryRowContext(ctx, ` - INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands, + INSERT INTO "Network" ("user", name, addr, nick, username, realname, certfp, pass, connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, auto_away, enabled) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) RETURNING id`, - userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands, + userID, netName, network.Addr, nick, netUsername, realname, certfp, pass, connectCommands, saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob, network.AutoAway, network.Enabled).Scan(&network.ID) } else { _, err = db.db.ExecContext(ctx, ` UPDATE "Network" - SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7, - connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10, - sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13, + SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, certfp = $7, pass = $8, + connect_commands = $9, sasl_mechanism = $10, sasl_plain_username = $11, + sasl_plain_password = $12, sasl_external_cert = $13, sasl_external_key = $14, auto_away = $14, enabled = $15 WHERE id = $1`, - network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands, + network.ID, netName, network.Addr, nick, netUsername, realname, certfp, pass, connectCommands, saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob, network.AutoAway, network.Enabled) } diff --git a/database/sqlite.go b/database/sqlite.go index 90b05b1..11e83ee 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -42,6 +42,7 @@ CREATE TABLE Network ( nick TEXT, username TEXT, realname TEXT, + certfp TEXT, pass TEXT, connect_commands TEXT, sasl_mechanism TEXT, @@ -250,6 +251,7 @@ var sqliteMigrations = []string{ `, "ALTER TABLE User ADD COLUMN nick TEXT;", "ALTER TABLE Network ADD COLUMN auto_away INTEGER NOT NULL DEFAULT 1;", + "ALTER TABLE Network ADD COLUMN certfp TEXT;", } type SqliteDB struct { @@ -488,7 +490,7 @@ func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, defer cancel() rows, err := db.db.QueryContext(ctx, ` - SELECT id, name, addr, nick, username, realname, pass, + SELECT id, name, addr, nick, username, realname, certfp, pass, connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, auto_away, enabled FROM Network @@ -502,9 +504,9 @@ func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, var networks []Network for rows.Next() { var net Network - var name, nick, username, realname, pass, connectCommands sql.NullString + var name, nick, username, realname, certfp, pass, connectCommands sql.NullString var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString - err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, + err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, &certfp, &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.AutoAway, &net.Enabled) if err != nil { @@ -514,6 +516,7 @@ func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, net.Nick = nick.String net.Username = username.String net.Realname = realname.String + net.CertFP = certfp.String net.Pass = pass.String if connectCommands.Valid { net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") @@ -556,6 +559,7 @@ func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Net sql.Named("nick", toNullString(network.Nick)), sql.Named("username", toNullString(network.Username)), sql.Named("realname", toNullString(network.Realname)), + sql.Named("certfp", toNullString(network.CertFP)), sql.Named("pass", toNullString(network.Pass)), sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))), sql.Named("sasl_mechanism", saslMechanism), @@ -575,7 +579,7 @@ func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Net _, err = db.db.ExecContext(ctx, ` UPDATE Network SET name = :name, addr = :addr, nick = :nick, username = :username, - realname = :realname, pass = :pass, connect_commands = :connect_commands, + realname = :realname, certfp = :certfp, pass = :pass, connect_commands = :connect_commands, sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password, sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key, auto_away = :auto_away, enabled = :enabled @@ -583,10 +587,10 @@ func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Net } else { var res sql.Result res, err = db.db.ExecContext(ctx, ` - INSERT INTO Network(user, name, addr, nick, username, realname, pass, + INSERT INTO Network(user, name, addr, nick, username, realname, certfp, pass, connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, auto_away, enabled) - VALUES (:user, :name, :addr, :nick, :username, :realname, :pass, + VALUES (:user, :name, :addr, :nick, :username, :realname, :certfp, :pass, :connect_commands, :sasl_mechanism, :sasl_plain_username, :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :auto_away, :enabled)`, args...) diff --git a/doc/soju.1.scd b/doc/soju.1.scd index 4b699b7..3843a5c 100644 --- a/doc/soju.1.scd +++ b/doc/soju.1.scd @@ -213,6 +213,12 @@ abbreviated form, for instance *network* can be abbreviated as *net* or just Connect with the specified real name. By default, the account's realname is used if set, otherwise the network's nickname is used. + *-certfp* + Instead of using certificate authorities to check the server's TLS + certificate, check whether the server certificate matches the provided + fingerprint. This can be used to connect to servers using self-signed + certificates. The fingerprint format is SHA512. + *-nick* Connect with the specified nickname. By default, the account's username is used. diff --git a/service.go b/service.go index 64cb110..346310f 100644 --- a/service.go +++ b/service.go @@ -201,7 +201,7 @@ func init() { "network": { children: serviceCommandSet{ "create": { - usage: "-addr [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-auto-away auto-away] [-enabled enabled] [-connect-command command]...", + usage: "-addr [-name name] [-username username] [-pass pass] [-realname realname] [-certfp fingerprint] [-nick nick] [-auto-away auto-away] [-enabled enabled] [-connect-command command]...", desc: "add a new network", handle: handleServiceNetworkCreate, }, @@ -210,7 +210,7 @@ func init() { handle: handleServiceNetworkStatus, }, "update": { - usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-auto-away auto-away] [-enabled enabled] [-connect-command command]...", + usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-certfp fingerprint] [-nick nick] [-auto-away auto-away] [-enabled enabled] [-connect-command command]...", desc: "update a network", handle: handleServiceNetworkUpdate, }, @@ -435,9 +435,9 @@ func getNetworkFromArg(dc *downstreamConn, params []string) (*network, []string, type networkFlagSet struct { *flag.FlagSet - Addr, Name, Nick, Username, Pass, Realname *string - AutoAway, Enabled *bool - ConnectCommands []string + Addr, Name, Nick, Username, Pass, Realname, CertFP *string + AutoAway, Enabled *bool + ConnectCommands []string } func newNetworkFlagSet() *networkFlagSet { @@ -448,6 +448,7 @@ func newNetworkFlagSet() *networkFlagSet { fs.Var(stringPtrFlag{&fs.Username}, "username", "") fs.Var(stringPtrFlag{&fs.Pass}, "pass", "") fs.Var(stringPtrFlag{&fs.Realname}, "realname", "") + fs.Var(stringPtrFlag{&fs.CertFP}, "fingerprint", "") fs.Var(boolPtrFlag{&fs.AutoAway}, "auto-away", "") fs.Var(boolPtrFlag{&fs.Enabled}, "enabled", "") fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "") @@ -484,6 +485,19 @@ func (fs *networkFlagSet) update(network *database.Network) error { if fs.Realname != nil { network.Realname = *fs.Realname } + if fs.CertFP != nil { + certFP := strings.ReplaceAll(*fs.CertFP, ":", "") + if _, err := hex.DecodeString(certFP); err != nil { + return fmt.Errorf("the certificate fingerprint must be hex-encoded") + } + if len(certFP) == 64 { + network.CertFP = "sha-256:" + certFP + } else if len(certFP) == 128 { + network.CertFP = "sha-512:" + certFP + } else { + return fmt.Errorf("the certificate fingerprint must be a SHA256 or SHA512 hash") + } + } if fs.AutoAway != nil { network.AutoAway = *fs.AutoAway } diff --git a/upstream.go b/upstream.go index 542d742..19e024a 100644 --- a/upstream.go +++ b/upstream.go @@ -4,9 +4,11 @@ import ( "context" "crypto" "crypto/sha256" + "crypto/sha512" "crypto/tls" "crypto/x509" "encoding/base64" + "encoding/hex" "errors" "fmt" "io" @@ -285,6 +287,40 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob)) } + if network.CertFP != "" { + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + if len(rawCerts) == 0 { + return fmt.Errorf("the server didn't present any TLS certificate") + } + + parts := strings.SplitN(network.CertFP, ":", 2) + algo, localCertFP := parts[0], parts[1] + + for _, rawCert := range rawCerts { + var remoteCertFP string + switch algo { + case "sha-512": + sum := sha512.Sum512(rawCert) + remoteCertFP = hex.EncodeToString(sum[:]) + case "sha-256": + sum := sha256.Sum256(rawCert) + remoteCertFP = hex.EncodeToString(sum[:]) + } + + if remoteCertFP == localCertFP { + return nil // fingerprints match + } + } + + // Fingerprints don't match, let's give the user a fingerprint + // they can use to connect + sum := sha512.Sum512(rawCerts[0]) + remoteCertFP := hex.EncodeToString(sum[:]) + return fmt.Errorf("the configured TLS certificate fingerprint doesn't match the server's - %s", remoteCertFP) + } + } + netConn, err = dialer.DialContext(ctx, "tcp", addr) if err != nil { return nil, fmt.Errorf("failed to dial %q: %v", addr, err)