OSDN Git Service

net: Check for EINTR.
[pf3gnuchains/gcc-fork.git] / libgo / go / net / fd.go
1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // TODO(rsc): All the prints in this file should go to standard error.
6
7 package net
8
9 import (
10         "io"
11         "os"
12         "sync"
13         "syscall"
14         "time"
15 )
16
17 // Network file descriptor.
18 type netFD struct {
19         // locking/lifetime of sysfd
20         sysmu   sync.Mutex
21         sysref  int
22         closing bool
23
24         // immutable until Close
25         sysfd   int
26         family  int
27         proto   int
28         sysfile *os.File
29         cr      chan bool
30         cw      chan bool
31         net     string
32         laddr   Addr
33         raddr   Addr
34
35         // owned by client
36         rdeadline_delta int64
37         rdeadline       int64
38         rio             sync.Mutex
39         wdeadline_delta int64
40         wdeadline       int64
41         wio             sync.Mutex
42
43         // owned by fd wait server
44         ncr, ncw int
45 }
46
47 type InvalidConnError struct{}
48
49 func (e *InvalidConnError) String() string  { return "invalid net.Conn" }
50 func (e *InvalidConnError) Temporary() bool { return false }
51 func (e *InvalidConnError) Timeout() bool   { return false }
52
53 // A pollServer helps FDs determine when to retry a non-blocking
54 // read or write after they get EAGAIN.  When an FD needs to wait,
55 // send the fd on s.cr (for a read) or s.cw (for a write) to pass the
56 // request to the poll server.  Then receive on fd.cr/fd.cw.
57 // When the pollServer finds that i/o on FD should be possible
58 // again, it will send fd on fd.cr/fd.cw to wake any waiting processes.
59 // This protocol is implemented as s.WaitRead() and s.WaitWrite().
60 //
61 // There is one subtlety: when sending on s.cr/s.cw, the
62 // poll server is probably in a system call, waiting for an fd
63 // to become ready.  It's not looking at the request channels.
64 // To resolve this, the poll server waits not just on the FDs it has
65 // been given but also its own pipe.  After sending on the
66 // buffered channel s.cr/s.cw, WaitRead/WaitWrite writes a
67 // byte to the pipe, causing the pollServer's poll system call to
68 // return.  In response to the pipe being readable, the pollServer
69 // re-polls its request channels.
70 //
71 // Note that the ordering is "send request" and then "wake up server".
72 // If the operations were reversed, there would be a race: the poll
73 // server might wake up and look at the request channel, see that it
74 // was empty, and go back to sleep, all before the requester managed
75 // to send the request.  Because the send must complete before the wakeup,
76 // the request channel must be buffered.  A buffer of size 1 is sufficient
77 // for any request load.  If many processes are trying to submit requests,
78 // one will succeed, the pollServer will read the request, and then the
79 // channel will be empty for the next process's request.  A larger buffer
80 // might help batch requests.
81 //
82 // To avoid races in closing, all fd operations are locked and
83 // refcounted. when netFD.Close() is called, it calls syscall.Shutdown
84 // and sets a closing flag. Only when the last reference is removed
85 // will the fd be closed.
86
87 type pollServer struct {
88         cr, cw   chan *netFD // buffered >= 1
89         pr, pw   *os.File
90         pending  map[int]*netFD
91         poll     *pollster // low-level OS hooks
92         deadline int64     // next deadline (nsec since 1970)
93 }
94
95 func (s *pollServer) AddFD(fd *netFD, mode int) {
96         intfd := fd.sysfd
97         if intfd < 0 {
98                 // fd closed underfoot
99                 if mode == 'r' {
100                         fd.cr <- true
101                 } else {
102                         fd.cw <- true
103                 }
104                 return
105         }
106         if err := s.poll.AddFD(intfd, mode, false); err != nil {
107                 panic("pollServer AddFD " + err.String())
108                 return
109         }
110
111         var t int64
112         key := intfd << 1
113         if mode == 'r' {
114                 fd.ncr++
115                 t = fd.rdeadline
116         } else {
117                 fd.ncw++
118                 key++
119                 t = fd.wdeadline
120         }
121         s.pending[key] = fd
122         if t > 0 && (s.deadline == 0 || t < s.deadline) {
123                 s.deadline = t
124         }
125 }
126
127 func (s *pollServer) LookupFD(fd int, mode int) *netFD {
128         key := fd << 1
129         if mode == 'w' {
130                 key++
131         }
132         netfd, ok := s.pending[key]
133         if !ok {
134                 return nil
135         }
136         s.pending[key] = nil, false
137         return netfd
138 }
139
140 func (s *pollServer) WakeFD(fd *netFD, mode int) {
141         if mode == 'r' {
142                 for fd.ncr > 0 {
143                         fd.ncr--
144                         fd.cr <- true
145                 }
146         } else {
147                 for fd.ncw > 0 {
148                         fd.ncw--
149                         fd.cw <- true
150                 }
151         }
152 }
153
154 func (s *pollServer) Now() int64 {
155         return time.Nanoseconds()
156 }
157
158 func (s *pollServer) CheckDeadlines() {
159         now := s.Now()
160         // TODO(rsc): This will need to be handled more efficiently,
161         // probably with a heap indexed by wakeup time.
162
163         var next_deadline int64
164         for key, fd := range s.pending {
165                 var t int64
166                 var mode int
167                 if key&1 == 0 {
168                         mode = 'r'
169                 } else {
170                         mode = 'w'
171                 }
172                 if mode == 'r' {
173                         t = fd.rdeadline
174                 } else {
175                         t = fd.wdeadline
176                 }
177                 if t > 0 {
178                         if t <= now {
179                                 s.pending[key] = nil, false
180                                 if mode == 'r' {
181                                         s.poll.DelFD(fd.sysfd, mode)
182                                         fd.rdeadline = -1
183                                 } else {
184                                         s.poll.DelFD(fd.sysfd, mode)
185                                         fd.wdeadline = -1
186                                 }
187                                 s.WakeFD(fd, mode)
188                         } else if next_deadline == 0 || t < next_deadline {
189                                 next_deadline = t
190                         }
191                 }
192         }
193         s.deadline = next_deadline
194 }
195
196 func (s *pollServer) Run() {
197         var scratch [100]byte
198         for {
199                 var t = s.deadline
200                 if t > 0 {
201                         t = t - s.Now()
202                         if t <= 0 {
203                                 s.CheckDeadlines()
204                                 continue
205                         }
206                 }
207                 fd, mode, err := s.poll.WaitFD(t)
208                 if err != nil {
209                         print("pollServer WaitFD: ", err.String(), "\n")
210                         return
211                 }
212                 if fd < 0 {
213                         // Timeout happened.
214                         s.CheckDeadlines()
215                         continue
216                 }
217                 if fd == s.pr.Fd() {
218                         // Drain our wakeup pipe.
219                         for nn, _ := s.pr.Read(scratch[0:]); nn > 0; {
220                                 nn, _ = s.pr.Read(scratch[0:])
221                         }
222                         // Read from channels
223                         for fd, ok := <-s.cr; ok; fd, ok = <-s.cr {
224                                 s.AddFD(fd, 'r')
225                         }
226                         for fd, ok := <-s.cw; ok; fd, ok = <-s.cw {
227                                 s.AddFD(fd, 'w')
228                         }
229                 } else {
230                         netfd := s.LookupFD(fd, mode)
231                         if netfd == nil {
232                                 print("pollServer: unexpected wakeup for fd=", fd, " mode=", string(mode), "\n")
233                                 continue
234                         }
235                         s.WakeFD(netfd, mode)
236                 }
237         }
238 }
239
240 var wakeupbuf [1]byte
241
242 func (s *pollServer) Wakeup() { s.pw.Write(wakeupbuf[0:]) }
243
244 func (s *pollServer) WaitRead(fd *netFD) {
245         s.cr <- fd
246         s.Wakeup()
247         <-fd.cr
248 }
249
250 func (s *pollServer) WaitWrite(fd *netFD) {
251         s.cw <- fd
252         s.Wakeup()
253         <-fd.cw
254 }
255
256 // Network FD methods.
257 // All the network FDs use a single pollServer.
258
259 var pollserver *pollServer
260 var onceStartServer sync.Once
261
262 func startServer() {
263         p, err := newPollServer()
264         if err != nil {
265                 print("Start pollServer: ", err.String(), "\n")
266         }
267         pollserver = p
268 }
269
270 func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err os.Error) {
271         onceStartServer.Do(startServer)
272         if e := syscall.SetNonblock(fd, true); e != 0 {
273                 return nil, &OpError{"setnonblock", net, laddr, os.Errno(e)}
274         }
275         f = &netFD{
276                 sysfd:  fd,
277                 family: family,
278                 proto:  proto,
279                 net:    net,
280                 laddr:  laddr,
281                 raddr:  raddr,
282         }
283         var ls, rs string
284         if laddr != nil {
285                 ls = laddr.String()
286         }
287         if raddr != nil {
288                 rs = raddr.String()
289         }
290         f.sysfile = os.NewFile(fd, net+":"+ls+"->"+rs)
291         f.cr = make(chan bool, 1)
292         f.cw = make(chan bool, 1)
293         return f, nil
294 }
295
296 // Add a reference to this fd.
297 func (fd *netFD) incref() {
298         fd.sysmu.Lock()
299         fd.sysref++
300         fd.sysmu.Unlock()
301 }
302
303 // Remove a reference to this FD and close if we've been asked to do so (and
304 // there are no references left.
305 func (fd *netFD) decref() {
306         fd.sysmu.Lock()
307         fd.sysref--
308         if fd.closing && fd.sysref == 0 && fd.sysfd >= 0 {
309                 // In case the user has set linger, switch to blocking mode so
310                 // the close blocks.  As long as this doesn't happen often, we
311                 // can handle the extra OS processes.  Otherwise we'll need to
312                 // use the pollserver for Close too.  Sigh.
313                 syscall.SetNonblock(fd.sysfd, false)
314                 fd.sysfile.Close()
315                 fd.sysfile = nil
316                 fd.sysfd = -1
317         }
318         fd.sysmu.Unlock()
319 }
320
321 func (fd *netFD) Close() os.Error {
322         if fd == nil || fd.sysfile == nil {
323                 return os.EINVAL
324         }
325
326         fd.incref()
327         syscall.Shutdown(fd.sysfd, syscall.SHUT_RDWR)
328         fd.closing = true
329         fd.decref()
330         return nil
331 }
332
333 func (fd *netFD) Read(p []byte) (n int, err os.Error) {
334         if fd == nil {
335                 return 0, os.EINVAL
336         }
337         fd.rio.Lock()
338         defer fd.rio.Unlock()
339         fd.incref()
340         defer fd.decref()
341         if fd.sysfile == nil {
342                 return 0, os.EINVAL
343         }
344         if fd.rdeadline_delta > 0 {
345                 fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
346         } else {
347                 fd.rdeadline = 0
348         }
349         var oserr os.Error
350         for {
351                 var errno int
352                 n, errno = syscall.Read(fd.sysfile.Fd(), p)
353                 if (errno == syscall.EAGAIN || errno == syscall.EINTR) && fd.rdeadline >= 0 {
354                         pollserver.WaitRead(fd)
355                         continue
356                 }
357                 if errno != 0 {
358                         n = 0
359                         oserr = os.Errno(errno)
360                 } else if n == 0 && errno == 0 && fd.proto != syscall.SOCK_DGRAM {
361                         err = os.EOF
362                 }
363                 break
364         }
365         if oserr != nil {
366                 err = &OpError{"read", fd.net, fd.raddr, oserr}
367         }
368         return
369 }
370
371 func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) {
372         if fd == nil || fd.sysfile == nil {
373                 return 0, nil, os.EINVAL
374         }
375         fd.rio.Lock()
376         defer fd.rio.Unlock()
377         fd.incref()
378         defer fd.decref()
379         if fd.rdeadline_delta > 0 {
380                 fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
381         } else {
382                 fd.rdeadline = 0
383         }
384         var oserr os.Error
385         for {
386                 var errno int
387                 n, sa, errno = syscall.Recvfrom(fd.sysfd, p, 0)
388                 if (errno == syscall.EAGAIN || errno == syscall.EINTR) && fd.rdeadline >= 0 {
389                         pollserver.WaitRead(fd)
390                         continue
391                 }
392                 if errno != 0 {
393                         n = 0
394                         oserr = os.Errno(errno)
395                 }
396                 break
397         }
398         if oserr != nil {
399                 err = &OpError{"read", fd.net, fd.laddr, oserr}
400         }
401         return
402 }
403
404 func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err os.Error) {
405         if fd == nil || fd.sysfile == nil {
406                 return 0, 0, 0, nil, os.EINVAL
407         }
408         fd.rio.Lock()
409         defer fd.rio.Unlock()
410         fd.incref()
411         defer fd.decref()
412         if fd.rdeadline_delta > 0 {
413                 fd.rdeadline = pollserver.Now() + fd.rdeadline_delta
414         } else {
415                 fd.rdeadline = 0
416         }
417         var oserr os.Error
418         for {
419                 var errno int
420                 n, oobn, flags, sa, errno = syscall.Recvmsg(fd.sysfd, p, oob, 0)
421                 if (errno == syscall.EAGAIN || errno == syscall.EINTR) && fd.rdeadline >= 0 {
422                         pollserver.WaitRead(fd)
423                         continue
424                 }
425                 if errno != 0 {
426                         oserr = os.Errno(errno)
427                 }
428                 if n == 0 {
429                         oserr = os.EOF
430                 }
431                 break
432         }
433         if oserr != nil {
434                 err = &OpError{"read", fd.net, fd.laddr, oserr}
435                 return
436         }
437         return
438 }
439
440 func (fd *netFD) Write(p []byte) (n int, err os.Error) {
441         if fd == nil {
442                 return 0, os.EINVAL
443         }
444         fd.wio.Lock()
445         defer fd.wio.Unlock()
446         fd.incref()
447         defer fd.decref()
448         if fd.sysfile == nil {
449                 return 0, os.EINVAL
450         }
451         if fd.wdeadline_delta > 0 {
452                 fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
453         } else {
454                 fd.wdeadline = 0
455         }
456         nn := 0
457         var oserr os.Error
458
459         for {
460                 n, errno := syscall.Write(fd.sysfile.Fd(), p[nn:])
461                 if n > 0 {
462                         nn += n
463                 }
464                 if nn == len(p) {
465                         break
466                 }
467                 if (errno == syscall.EAGAIN || errno == syscall.EINTR) && fd.wdeadline >= 0 {
468                         pollserver.WaitWrite(fd)
469                         continue
470                 }
471                 if errno != 0 {
472                         n = 0
473                         oserr = os.Errno(errno)
474                         break
475                 }
476                 if n == 0 {
477                         oserr = io.ErrUnexpectedEOF
478                         break
479                 }
480         }
481         if oserr != nil {
482                 err = &OpError{"write", fd.net, fd.raddr, oserr}
483         }
484         return nn, err
485 }
486
487 func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) {
488         if fd == nil || fd.sysfile == nil {
489                 return 0, os.EINVAL
490         }
491         fd.wio.Lock()
492         defer fd.wio.Unlock()
493         fd.incref()
494         defer fd.decref()
495         if fd.wdeadline_delta > 0 {
496                 fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
497         } else {
498                 fd.wdeadline = 0
499         }
500         var oserr os.Error
501         for {
502                 errno := syscall.Sendto(fd.sysfd, p, 0, sa)
503                 if (errno == syscall.EAGAIN || errno == syscall.EINTR) && fd.wdeadline >= 0 {
504                         pollserver.WaitWrite(fd)
505                         continue
506                 }
507                 if errno != 0 {
508                         oserr = os.Errno(errno)
509                 }
510                 break
511         }
512         if oserr == nil {
513                 n = len(p)
514         } else {
515                 err = &OpError{"write", fd.net, fd.raddr, oserr}
516         }
517         return
518 }
519
520 func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err os.Error) {
521         if fd == nil || fd.sysfile == nil {
522                 return 0, 0, os.EINVAL
523         }
524         fd.wio.Lock()
525         defer fd.wio.Unlock()
526         fd.incref()
527         defer fd.decref()
528         if fd.wdeadline_delta > 0 {
529                 fd.wdeadline = pollserver.Now() + fd.wdeadline_delta
530         } else {
531                 fd.wdeadline = 0
532         }
533         var oserr os.Error
534         for {
535                 var errno int
536                 errno = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0)
537                 if (errno == syscall.EAGAIN || errno == syscall.EINTR) && fd.wdeadline >= 0 {
538                         pollserver.WaitWrite(fd)
539                         continue
540                 }
541                 if errno != 0 {
542                         oserr = os.Errno(errno)
543                 }
544                 break
545         }
546         if oserr == nil {
547                 n = len(p)
548                 oobn = len(oob)
549         } else {
550                 err = &OpError{"write", fd.net, fd.raddr, oserr}
551         }
552         return
553 }
554
555 func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) {
556         if fd == nil || fd.sysfile == nil {
557                 return nil, os.EINVAL
558         }
559
560         fd.incref()
561         defer fd.decref()
562
563         // See ../syscall/exec.go for description of ForkLock.
564         // It is okay to hold the lock across syscall.Accept
565         // because we have put fd.sysfd into non-blocking mode.
566         syscall.ForkLock.RLock()
567         var s, e int
568         var sa syscall.Sockaddr
569         for {
570                 if fd.closing {
571                         syscall.ForkLock.RUnlock()
572                         return nil, os.EINVAL
573                 }
574                 s, sa, e = syscall.Accept(fd.sysfd)
575                 if e != syscall.EAGAIN && e != syscall.EINTR {
576                         break
577                 }
578                 syscall.ForkLock.RUnlock()
579                 pollserver.WaitRead(fd)
580                 syscall.ForkLock.RLock()
581         }
582         if e != 0 {
583                 syscall.ForkLock.RUnlock()
584                 return nil, &OpError{"accept", fd.net, fd.laddr, os.Errno(e)}
585         }
586         syscall.CloseOnExec(s)
587         syscall.ForkLock.RUnlock()
588
589         if nfd, err = newFD(s, fd.family, fd.proto, fd.net, fd.laddr, toAddr(sa)); err != nil {
590                 syscall.Close(s)
591                 return nil, err
592         }
593         return nfd, nil
594 }
595
596 func (fd *netFD) dup() (f *os.File, err os.Error) {
597         ns, e := syscall.Dup(fd.sysfd)
598         if e != 0 {
599                 return nil, &OpError{"dup", fd.net, fd.laddr, os.Errno(e)}
600         }
601
602         // We want blocking mode for the new fd, hence the double negative.
603         if e = syscall.SetNonblock(ns, false); e != 0 {
604                 return nil, &OpError{"setnonblock", fd.net, fd.laddr, os.Errno(e)}
605         }
606
607         return os.NewFile(ns, fd.sysfile.Name()), nil
608 }
609
610 func closesocket(s int) (errno int) {
611         return syscall.Close(s)
612 }