diff options
Diffstat (limited to 'src/os/user/userdbclient_linux.go')
| -rw-r--r-- | src/os/user/userdbclient_linux.go | 772 |
1 files changed, 772 insertions, 0 deletions
diff --git a/src/os/user/userdbclient_linux.go b/src/os/user/userdbclient_linux.go new file mode 100644 index 0000000000..e585b7f3c3 --- /dev/null +++ b/src/os/user/userdbclient_linux.go @@ -0,0 +1,772 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux + +package user + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "io/fs" + "os" + "strconv" + "strings" + "sync" + "syscall" + "unicode/utf16" + "unicode/utf8" +) + +const ( + // Well known multiplexer service. + svcMultiplexer = "io.systemd.Multiplexer" + + userdbNamespace = "io.systemd.UserDatabase" + + // io.systemd.UserDatabase VARLINK interface methods. + mGetGroupRecord = userdbNamespace + ".GetGroupRecord" + mGetUserRecord = userdbNamespace + ".GetUserRecord" + mGetMemberships = userdbNamespace + ".GetMemberships" + + // io.systemd.UserDatabase VARLINK interface errors. + errNoRecordFound = userdbNamespace + ".NoRecordFound" + errServiceNotAvailable = userdbNamespace + ".ServiceNotAvailable" +) + +func init() { + defaultUserdbClient.dir = "/run/systemd/userdb" +} + +// userdbCall represents a VARLINK service call sent to systemd-userdb. +// method is the VARLINK method to call. +// parameters are the VARLINK parameters to pass. +// more indicates if more responses are expected. +// fastest indicates if only the fastest response should be returned. +type userdbCall struct { + method string + parameters callParameters + more bool + fastest bool +} + +func (u userdbCall) marshalJSON(service string) ([]byte, error) { + params, err := u.parameters.marshalJSON(service) + if err != nil { + return nil, err + } + var data bytes.Buffer + data.WriteString(`{"method":"`) + data.WriteString(u.method) + data.WriteString(`","parameters":`) + data.Write(params) + if u.more { + data.WriteString(`,"more":true`) + } + data.WriteString(`}`) + return data.Bytes(), nil +} + +type callParameters struct { + uid *int64 + userName string + gid *int64 + groupName string +} + +func (c callParameters) marshalJSON(service string) ([]byte, error) { + var data bytes.Buffer + data.WriteString(`{"service":"`) + data.WriteString(service) + data.WriteString(`"`) + if c.uid != nil { + data.WriteString(`,"uid":`) + data.WriteString(strconv.FormatInt(*c.uid, 10)) + } + if c.userName != "" { + data.WriteString(`,"userName":"`) + data.WriteString(c.userName) + data.WriteString(`"`) + } + if c.gid != nil { + data.WriteString(`,"gid":`) + data.WriteString(strconv.FormatInt(*c.gid, 10)) + } + if c.groupName != "" { + data.WriteString(`,"groupName":"`) + data.WriteString(c.groupName) + data.WriteString(`"`) + } + data.WriteString(`}`) + return data.Bytes(), nil +} + +type userdbReply struct { + continues bool + errorStr string +} + +func (u *userdbReply) unmarshalJSON(data []byte) error { + var ( + kContinues = []byte(`"continues"`) + kError = []byte(`"error"`) + ) + if i := bytes.Index(data, kContinues); i != -1 { + continues, err := parseJSONBoolean(data[i+len(kContinues):]) + if err != nil { + return err + } + u.continues = continues + } + if i := bytes.Index(data, kError); i != -1 { + errStr, err := parseJSONString(data[i+len(kError):]) + if err != nil { + return err + } + u.errorStr = errStr + } + return nil +} + +// response is the parsed reply from a method call to systemd-userdb. +// data is one or more VARLINK response parameters separated by 0. +// handled indicates if the call was handled by systemd-userdb. +// err is any error encountered. +type response struct { + data []byte + handled bool + err error +} + +// querySocket calls the io.systemd.UserDatabase VARLINK interface at sock with request. +// Multiple replies can be read by setting more to true in the request. +// Reply parameters are accumulated separated by 0, if there are many. +// Replies with io.systemd.UserDatabase.NoRecordFound errors are skipped. +// Other UserDatabase errors are returned as is. +// If the socket does not exist, or if the io.systemd.UserDatabase.ServiceNotAvailable +// error is seen in a response, the query is considered unhandled. +func querySocket(ctx context.Context, sock string, request []byte) response { + sockFd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + if err != nil { + return response{err: err} + } + defer syscall.Close(sockFd) + if err := syscall.Connect(sockFd, &syscall.SockaddrUnix{Name: sock}); err != nil { + if errors.Is(err, os.ErrNotExist) { + return response{err: err} + } + return response{handled: true, err: err} + } + + // Null terminate request. + if request[len(request)-1] != 0 { + request = append(request, 0) + } + + // Write request to socket. + written := 0 + for written < len(request) { + if ctx.Err() != nil { + return response{handled: true, err: ctx.Err()} + } + if n, err := syscall.Write(sockFd, request[written:]); err != nil { + return response{handled: true, err: err} + } else { + written += n + } + } + + // Read response. + var resp bytes.Buffer + for { + if ctx.Err() != nil { + return response{handled: true, err: ctx.Err()} + } + buf := make([]byte, 4096) + if n, err := syscall.Read(sockFd, buf); err != nil { + return response{handled: true, err: err} + } else if n > 0 { + resp.Write(buf[:n]) + if buf[n-1] == 0 { + break + } + } else { + // EOF + break + } + } + + if resp.Len() == 0 { + return response{handled: true} + } + + buf := resp.Bytes() + // Remove trailing 0. + buf = buf[:len(buf)-1] + // Split into VARLINK messages. + msgs := bytes.Split(buf, []byte{0}) + + // Parse VARLINK messages. + for _, m := range msgs { + var resp userdbReply + if err := resp.unmarshalJSON(m); err != nil { + return response{handled: true, err: err} + } + // Handle VARLINK message errors. + switch e := resp.errorStr; e { + case "": + case errNoRecordFound: // Ignore not found error. + continue + case errServiceNotAvailable: + return response{} + default: + return response{handled: true, err: errors.New(e)} + } + if !resp.continues { + break + } + } + return response{data: buf, handled: true, err: ctx.Err()} +} + +// queryMany calls the io.systemd.UserDatabase VARLINK interface on many services at once. +// ss is a slice of userdb services to call. Each service must have a socket in cl.dir. +// c is sent to all services in ss. If c.fastest is true, only the fastest reply is read. +// Otherwise all replies are aggregated. um is called with aggregated reply parameters. +// queryMany returns the first error encountered. The first result is false if no userdb +// socket is available or if all requests time out. +func (cl userdbClient) queryMany(ctx context.Context, ss []string, c *userdbCall, um jsonUnmarshaler) (bool, error) { + responseCh := make(chan response, len(ss)) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Query all services in parallel. + var workers sync.WaitGroup + for _, svc := range ss { + data, err := c.marshalJSON(svc) + if err != nil { + return true, err + } + // Spawn worker to query service. + workers.Add(1) + go func(sock string, data []byte) { + defer workers.Done() + responseCh <- querySocket(ctx, sock, data) + }(cl.dir+"/"+svc, data) + } + + go func() { + // Clean up workers. + workers.Wait() + close(responseCh) + }() + + var result bytes.Buffer + var notOk int +RecvResponses: + for { + select { + case resp, ok := <-responseCh: + if !ok { + // Responses channel is closed so stop reading. + break RecvResponses + } + if resp.err != nil { + // querySocket only returns unrecoverable errors, + // so return the first one received. + return true, resp.err + } + if !resp.handled { + notOk++ + continue + } + + first := result.Len() == 0 + result.Write(resp.data) + if first && c.fastest { + // Return the fastest response. + break RecvResponses + } + case <-ctx.Done(): + // If requests time out, userdb is unavailable. + return ctx.Err() != context.DeadlineExceeded, nil + } + } + // If all sockets are not ok, userdb is unavailable. + if notOk == len(ss) { + return false, nil + } + return true, um.unmarshalJSON(result.Bytes()) +} + +// services enumerates userdb service sockets in dir. +// If ok is false, io.systemd.UserDatabase service does not exist. +func (cl userdbClient) services() (s []string, ok bool, err error) { + var entries []fs.DirEntry + if entries, err = os.ReadDir(cl.dir); err != nil { + ok = !os.IsNotExist(err) + return + } + ok = true + for _, ent := range entries { + s = append(s, ent.Name()) + } + return +} + +// query looks up users/groups on the io.systemd.UserDatabase VARLINK interface. +// If the multiplexer service is available, the call is sent only to it. +// Otherwise, the call is sent simultaneously to all UserDatabase services in cl.dir. +// The fastest reply is read and parsed. All other requests are cancelled. +// If the service is unavailable, the first result is false. +// The service is considered unavailable if the requests time-out as well. +func (cl userdbClient) query(ctx context.Context, call *userdbCall, um jsonUnmarshaler) (bool, error) { + services := []string{svcMultiplexer} + if _, err := os.Stat(cl.dir + "/" + svcMultiplexer); err != nil { + // No mux service so call all available services. + var ok bool + if services, ok, err = cl.services(); !ok || err != nil { + return ok, err + } + } + call.fastest = true + if ok, err := cl.queryMany(ctx, services, call, um); !ok || err != nil { + return ok, err + } + return true, nil +} + +type jsonUnmarshaler interface { + unmarshalJSON([]byte) error +} + +func isSpace(c byte) bool { + return c == ' ' || c == '\t' || c == '\r' || c == '\n' +} + +// findElementStart returns a slice of r that starts at the next JSON element. +// It skips over valid JSON space characters and checks for the colon separator. +func findElementStart(r []byte) ([]byte, error) { + var idx int + var b byte + colon := byte(':') + var seenColon bool + for idx, b = range r { + if isSpace(b) { + continue + } + if !seenColon && b == colon { + seenColon = true + continue + } + // Spotted colon and b is not a space, so value starts here. + if seenColon { + break + } + return nil, errors.New("expected colon, got invalid character: " + string(b)) + } + if !seenColon { + return nil, errors.New("expected colon, got end of input") + } + return r[idx:], nil +} + +// parseJSONString reads a JSON string from r. +func parseJSONString(r []byte) (string, error) { + r, err := findElementStart(r) + if err != nil { + return "", err + } + // Smallest valid string is `""`. + if l := len(r); l < 2 { + return "", errors.New("unexpected end of input") + } else if l == 2 { + if bytes.Equal(r, []byte(`""`)) { + return "", nil + } + return "", errors.New("invalid string") + } + + if c := r[0]; c != '"' { + return "", errors.New(`expected " got ` + string(c)) + } + // Advance over opening quote. + r = r[1:] + + var value strings.Builder + var inEsc bool + var inUEsc bool + var strEnds bool + reader := bytes.NewReader(r) + for { + if value.Len() > 4096 { + return "", errors.New("string too large") + } + + // Parse unicode escape sequences. + if inUEsc { + maybeRune := make([]byte, 4) + n, err := reader.Read(maybeRune) + if err != nil || n != 4 { + return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune)) + } + prn, err := strconv.ParseUint(string(maybeRune), 16, 32) + if err != nil { + return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune)) + } + rn := rune(prn) + if !utf16.IsSurrogate(rn) { + value.WriteRune(rn) + inUEsc = false + continue + } + // rn maybe a high surrogate; read the low surrogate. + maybeRune = make([]byte, 6) + n, err = reader.Read(maybeRune) + if err != nil || n != 6 || maybeRune[0] != '\\' || maybeRune[1] != 'u' { + // Not a valid UTF-16 surrogate pair. + if _, err := reader.Seek(int64(-n), io.SeekCurrent); err != nil { + return "", err + } + // Invalid low surrogate; write the replacement character. + value.WriteRune(utf8.RuneError) + } else { + rn1, err := strconv.ParseUint(string(maybeRune[2:]), 16, 32) + if err != nil { + return "", fmt.Errorf("invalid unicode escape sequence %s", string(maybeRune)) + } + // Check if rn and rn1 are valid UTF-16 surrogate pairs. + if dec := utf16.DecodeRune(rn, rune(rn1)); dec != utf8.RuneError { + n = utf8.EncodeRune(maybeRune, dec) + // Write the decoded rune. + value.Write(maybeRune[:n]) + } + } + inUEsc = false + continue + } + + if inEsc { + b, err := reader.ReadByte() + if err != nil { + return "", err + } + switch b { + case 'b': + value.WriteByte('\b') + case 'f': + value.WriteByte('\f') + case 'n': + value.WriteByte('\n') + case 'r': + value.WriteByte('\r') + case 't': + value.WriteByte('\t') + case 'u': + inUEsc = true + case '/': + value.WriteByte('/') + case '\\': + value.WriteByte('\\') + case '"': + value.WriteByte('"') + default: + return "", errors.New("unexpected character in escape sequence " + string(b)) + } + inEsc = false + continue + } else { + rn, _, err := reader.ReadRune() + if err != nil { + if err == io.EOF { + break + } + return "", err + } + if rn == '\\' { + inEsc = true + continue + } + if rn == '"' { + // String ends on un-escaped quote. + strEnds = true + break + } + value.WriteRune(rn) + } + } + if !strEnds { + return "", errors.New("unexpected end of input") + } + return value.String(), nil +} + +// parseJSONInt64 reads a 64 bit integer from r. +func parseJSONInt64(r []byte) (int64, error) { + r, err := findElementStart(r) + if err != nil { + return 0, err + } + var num strings.Builder + for _, b := range r { + // int64 max is 19 digits long. + if num.Len() == 20 { + return 0, errors.New("number too large") + } + if strings.ContainsRune("0123456789", rune(b)) { + num.WriteByte(b) + } else { + break + } + } + n, err := strconv.ParseInt(num.String(), 10, 64) + return int64(n), err +} + +// parseJSONBoolean reads a boolean from r. +func parseJSONBoolean(r []byte) (bool, error) { + r, err := findElementStart(r) + if err != nil { + return false, err + } + if bytes.HasPrefix(r, []byte("true")) { + return true, nil + } + if bytes.HasPrefix(r, []byte("false")) { + return false, nil + } + return false, errors.New("unable to parse boolean value") +} + +type groupRecord struct { + groupName string + gid int64 +} + +func (g *groupRecord) unmarshalJSON(data []byte) error { + var ( + kGroupName = []byte(`"groupName"`) + kGid = []byte(`"gid"`) + ) + if i := bytes.Index(data, kGroupName); i != -1 { + groupname, err := parseJSONString(data[i+len(kGroupName):]) + if err != nil { + return err + } + g.groupName = groupname + } + if i := bytes.Index(data, kGid); i != -1 { + gid, err := parseJSONInt64(data[i+len(kGid):]) + if err != nil { + return err + } + g.gid = gid + } + return nil +} + +// queryGroupDb queries the userdb interface for a gid, groupname, or both. +func (cl userdbClient) queryGroupDb(ctx context.Context, gid *int64, groupname string) (*Group, bool, error) { + group := groupRecord{} + request := userdbCall{ + method: mGetGroupRecord, + parameters: callParameters{gid: gid, groupName: groupname}, + } + if ok, err := cl.query(ctx, &request, &group); !ok || err != nil { + return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err) + } + return &Group{ + Name: group.groupName, + Gid: strconv.FormatInt(group.gid, 10), + }, true, nil +} + +type userRecord struct { + userName string + realName string + uid int64 + gid int64 + homeDirectory string +} + +func (u *userRecord) unmarshalJSON(data []byte) error { + var ( + kUserName = []byte(`"userName"`) + kRealName = []byte(`"realName"`) + kUid = []byte(`"uid"`) + kGid = []byte(`"gid"`) + kHomeDirectory = []byte(`"homeDirectory"`) + ) + if i := bytes.Index(data, kUserName); i != -1 { + username, err := parseJSONString(data[i+len(kUserName):]) + if err != nil { + return err + } + u.userName = username + } + if i := bytes.Index(data, kRealName); i != -1 { + realname, err := parseJSONString(data[i+len(kRealName):]) + if err != nil { + return err + } + u.realName = realname + } + if i := bytes.Index(data, kUid); i != -1 { + uid, err := parseJSONInt64(data[i+len(kUid):]) + if err != nil { + return err + } + u.uid = uid + } + if i := bytes.Index(data, kGid); i != -1 { + gid, err := parseJSONInt64(data[i+len(kGid):]) + if err != nil { + return err + } + u.gid = gid + } + if i := bytes.Index(data, kHomeDirectory); i != -1 { + homedir, err := parseJSONString(data[i+len(kHomeDirectory):]) + if err != nil { + return err + } + u.homeDirectory = homedir + } + return nil +} + +// queryUserDb queries the userdb interface for a uid, username, or both. +func (cl userdbClient) queryUserDb(ctx context.Context, uid *int64, username string) (*User, bool, error) { + user := userRecord{} + request := userdbCall{ + method: mGetUserRecord, + parameters: callParameters{ + uid: uid, + userName: username, + }, + } + if ok, err := cl.query(ctx, &request, &user); !ok || err != nil { + return nil, ok, fmt.Errorf("error querying systemd-userdb user record: %s", err) + } + return &User{ + Uid: strconv.FormatInt(user.uid, 10), + Gid: strconv.FormatInt(user.gid, 10), + Username: user.userName, + Name: user.realName, + HomeDir: user.homeDirectory, + }, true, nil +} + +func (cl userdbClient) lookupGroup(ctx context.Context, groupname string) (*Group, bool, error) { + return cl.queryGroupDb(ctx, nil, groupname) +} + +func (cl userdbClient) lookupGroupId(ctx context.Context, id string) (*Group, bool, error) { + gid, err := strconv.ParseInt(id, 10, 64) + if err != nil { + return nil, true, err + } + return cl.queryGroupDb(ctx, &gid, "") +} + +func (cl userdbClient) lookupUser(ctx context.Context, username string) (*User, bool, error) { + return cl.queryUserDb(ctx, nil, username) +} + +func (cl userdbClient) lookupUserId(ctx context.Context, id string) (*User, bool, error) { + uid, err := strconv.ParseInt(id, 10, 64) + if err != nil { + return nil, true, err + } + return cl.queryUserDb(ctx, &uid, "") +} + +type memberships struct { + // Keys are groupNames and values are sets of userNames. + groupUsers map[string]map[string]struct{} +} + +// unmarshalJSON expects many (userName, groupName) records separated by a null byte. +// This is used to build a membership map. +func (m *memberships) unmarshalJSON(data []byte) error { + if m.groupUsers == nil { + m.groupUsers = make(map[string]map[string]struct{}) + } + var ( + kUserName = []byte(`"userName"`) + kGroupName = []byte(`"groupName"`) + ) + // Split records by null terminator. + records := bytes.Split(data, []byte{byte(0)}) + for _, rec := range records { + if len(rec) == 0 { + continue + } + var groupName string + var userName string + var err error + if i := bytes.Index(rec, kGroupName); i != -1 { + if groupName, err = parseJSONString(rec[i+len(kGroupName):]); err != nil { + return err + } + } + if i := bytes.Index(rec, kUserName); i != -1 { + if userName, err = parseJSONString(rec[i+len(kUserName):]); err != nil { + return err + } + } + // Associate userName with groupName. + if groupName != "" && userName != "" { + if _, ok := m.groupUsers[groupName]; ok { + m.groupUsers[groupName][userName] = struct{}{} + } else { + m.groupUsers[groupName] = map[string]struct{}{userName: {}} + } + } + } + return nil +} + +func (cl userdbClient) lookupGroupIds(ctx context.Context, username string) ([]string, bool, error) { + services, ok, err := cl.services() + if !ok || err != nil { + return nil, ok, err + } + // Fetch group memberships for username. + var ms memberships + request := userdbCall{ + method: mGetMemberships, + parameters: callParameters{userName: username}, + more: true, + } + if ok, err := cl.queryMany(ctx, services, &request, &ms); !ok || err != nil { + return nil, ok, fmt.Errorf("error querying systemd-userdb memberships record: %s", err) + } + // Fetch user group gid. + var group groupRecord + request = userdbCall{ + method: mGetGroupRecord, + parameters: callParameters{groupName: username}, + } + if ok, err := cl.query(ctx, &request, &group); !ok || err != nil { + return nil, ok, err + } + gids := []string{strconv.FormatInt(group.gid, 10)} + + // Fetch group records for each group. + for g := range ms.groupUsers { + var group groupRecord + request.parameters.groupName = g + // Query group for gid. + if ok, err := cl.query(ctx, &request, &group); !ok || err != nil { + return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err) + } + gids = append(gids, strconv.FormatInt(group.gid, 10)) + } + return gids, true, nil +} |
