blob: 6c24968b8df3a915ea10fbeae6d5a306015bc918 [file] [log] [blame]
// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15package paths
16
17import (
18 "context"
19 "encoding/gob"
20 "fmt"
21 "io/ioutil"
22 "net"
23 "os"
24 "path/filepath"
25 "runtime"
26 "sync"
27 "syscall"
28 "time"
29)
30
// Records exchanged over the log socket. sendLog gob-encodes a *LogEntry and
// logListener decodes it; gob matches struct fields by name, so the field
// names below are effectively the wire format.
type (
	// LogProcess describes one process associated with a logged invocation.
	LogProcess struct {
		Pid     int    // process ID
		Command string // command recorded for that process by the sender
	}

	// LogEntry is a single record delivered through the log socket.
	LogEntry struct {
		Basename string       // NOTE(review): presumably the invoked program's base name — confirm with callers
		Args     []string     // arguments recorded for the invocation
		Parents  []LogProcess // related processes; presumably the ancestor chain — confirm ordering with callers
	}
)
41
42const timeoutDuration = time.Duration(100) * time.Millisecond
43
// socketAddrFunc translates a desired unix socket path into an address that
// can be passed to net.Dial/net.Listen. The returned cleanup must be non-nil
// on every return path (callers defer it before checking err) and should be
// run only after the socket has been dialed or bound.
type socketAddrFunc func(name string) (addr string, cleanup func(), err error)
45
// procFallback shortens name by opening its parent directory and addressing
// the socket as /proc/self/fd/<fd>/<base>. The directory stays open until the
// returned cleanup runs, so cleanup must be deferred until after the socket is
// dialed or bound. If the directory cannot be opened, the error is returned
// with a no-op cleanup.
func procFallback(name string) (string, func(), error) {
	dir, err := os.Open(filepath.Dir(name))
	if err != nil {
		return "", func() {}, err
	}

	closeDir := func() {
		dir.Close()
	}
	addr := fmt.Sprintf("/proc/self/fd/%d/%s", dir.Fd(), filepath.Base(name))
	return addr, closeDir, nil
}
56
// tmpFallback shortens name by creating a fresh directory under /tmp that
// contains a symlink "d" pointing at name's absolute parent directory, which
// yields the address /tmp/log_sockNNN/d/<base>.
//
// The returned cleanup removes that temporary directory. It is a no-op only
// when the directory could not be created. It is non-nil on every return
// path, including the error paths after the directory exists, so callers
// must always invoke it.
func tmpFallback(name string) (addr string, cleanup func(), err error) {
	tmp, err := ioutil.TempDir("/tmp", "log_sock")
	if err != nil {
		return "", func() {}, err
	}
	removeTmp := func() {
		os.RemoveAll(tmp)
	}

	parent, err := filepath.Abs(filepath.Dir(name))
	if err != nil {
		return "", removeTmp, err
	}

	link := filepath.Join(tmp, "d")
	if err := os.Symlink(parent, link); err != nil {
		return "", removeTmp, err
	}

	return filepath.Join(link, filepath.Base(name)), removeTmp, nil
}
83
84func getSocketAddr(name string) (string, func(), error) {
85 maxNameLen := len(syscall.RawSockaddrUnix{}.Path)
86
87 if len(name) < maxNameLen {
88 return name, func() {}, nil
89 }
90
91 if runtime.GOOS == "linux" {
92 addr, cleanup, err := procFallback(name)
93 if err == nil {
94 if len(addr) < maxNameLen {
95 return addr, cleanup, nil
96 }
97 }
98 cleanup()
99 }
100
101 addr, cleanup, err := tmpFallback(name)
102 if err == nil {
103 if len(addr) < maxNameLen {
104 return addr, cleanup, nil
105 }
106 }
107 cleanup()
108
109 return name, func() {}, fmt.Errorf("Path to socket is still over size limit, fallbacks failed.")
110}
111
112func dial(name string, lookup socketAddrFunc, timeout time.Duration) (net.Conn, error) {
113 socket, cleanup, err := lookup(name)
114 defer cleanup()
115 if err != nil {
116 return nil, err
117 }
118
119 dialer := &net.Dialer{
120 Timeout: timeout,
121 }
122 return dialer.Dial("unix", socket)
123}
124
125func listen(name string, lookup socketAddrFunc) (net.Listener, error) {
126 socket, cleanup, err := lookup(name)
127 defer cleanup()
128 if err != nil {
129 return nil, err
130 }
131
132 return net.Listen("unix", socket)
133}
134
135func SendLog(logSocket string, entry *LogEntry, done chan interface{}) {
136 sendLog(logSocket, getSocketAddr, timeoutDuration, entry, done)
137}
138
139func sendLog(logSocket string, lookup socketAddrFunc, timeout time.Duration, entry *LogEntry, done chan interface{}) {
140 defer close(done)
141
142 conn, err := dial(logSocket, lookup, timeout)
143 if err != nil {
144 return
145 }
146 defer conn.Close()
147
148 if timeout != 0 {
149 conn.SetDeadline(time.Now().Add(timeout))
150 }
151
152 enc := gob.NewEncoder(conn)
153 enc.Encode(entry)
154}
155
156func LogListener(ctx context.Context, logSocket string) (chan *LogEntry, error) {
157 return logListener(ctx, logSocket, getSocketAddr)
158}
159
160func logListener(ctx context.Context, logSocket string, lookup socketAddrFunc) (chan *LogEntry, error) {
161 ret := make(chan *LogEntry, 5)
162
163 if err := os.Remove(logSocket); err != nil && !os.IsNotExist(err) {
164 return nil, err
165 }
166
167 ln, err := listen(logSocket, lookup)
168 if err != nil {
169 return nil, err
170 }
171
172 go func() {
173 for {
174 select {
175 case <-ctx.Done():
176 ln.Close()
177 }
178 }
179 }()
180
181 go func() {
182 var wg sync.WaitGroup
183 defer func() {
184 wg.Wait()
185 close(ret)
186 }()
187
188 for {
189 conn, err := ln.Accept()
190 if err != nil {
191 ln.Close()
192 break
193 }
194 conn.SetDeadline(time.Now().Add(timeoutDuration))
195 wg.Add(1)
196
197 go func() {
198 defer wg.Done()
199 defer conn.Close()
200
201 dec := gob.NewDecoder(conn)
202 entry := &LogEntry{}
203 if err := dec.Decode(entry); err != nil {
204 return
205 }
206 ret <- entry
207 }()
208 }
209 }()
210 return ret, nil
211}