Coverage for rpc.py: 19%

437 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-11 13:22 -0700

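"""RPC Implementation, originally written for the Python Idle IDE

For security reasons, GvR requested that Idle's Python execution server process
connect to the Idle process, which listens for the connection.  Since Idle has
only one client per server, this was not a limitation.

   +---------------------------------+ +-------------+
   | socketserver.BaseRequestHandler | | SocketIO    |
   +---------------------------------+ +-------------+
                   ^                   | register()  |
                   |                   | unregister()|
                   |                   +-------------+
                   |                      ^  ^
                   |                      |  |
                   | + -------------------+  |
                   | |                       |
   +-------------------------+        +-----------------+
   | RPCHandler              |        | RPCClient       |
   | [attribute of RPCServer]|        |                 |
   +-------------------------+        +-----------------+

The RPCServer handler class is expected to provide register/unregister methods.
RPCHandler inherits the mix-in class SocketIO, which provides these methods.

See the Idle run.main() docstring for further information on how this was
accomplished in Idle.

"""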
12 | +-------------+ 

13 | ^ ^ 

14 | | | 

15 | + -------------------+ | 

16 | | | 

17 +-------------------------+ +-----------------+ 

18 | RPCHandler | | RPCClient | 

19 | [attribute of RPCServer]| | | 

20 +-------------------------+ +-----------------+ 

21 

22The RPCServer handler class is expected to provide register/unregister methods. 

23RPCHandler inherits the mix-in class SocketIO, which provides these methods. 

24 

25See the Idle run.main() docstring for further information on how this was 

26accomplished in Idle. 

27 

28""" 

29import builtins 

30import copyreg 

31import io 

32import marshal 

33import os 

34import pickle 

35import queue 

36import select 

37import socket 

38import socketserver 

39import struct 

40import sys 

41import threading 

42import traceback 

43import types 

44 

def unpickle_code(ms):
    """Return the code object marshalled in the bytes *ms*.

    pickle_code() names this function as the reconstructor for code
    objects, so pickle calls it while loading a message.

    Raises TypeError if *ms* does not decode to a code object.  An explicit
    check is used rather than ``assert`` so the validation is not stripped
    when Python runs with -O.
    """
    co = marshal.loads(ms)
    if not isinstance(co, types.CodeType):
        raise TypeError("expected a marshalled code object, got %s"
                        % type(co).__name__)
    return co

50 

def pickle_code(co):
    """Reduce the code object *co* for pickling.

    Returns the reconstructor unpickle_code paired with a one-element
    argument tuple holding marshal.dumps(co).  CodePickler's dispatch_table
    maps types.CodeType to this function.
    """
    assert isinstance(co, types.CodeType)
    marshalled = marshal.dumps(co)
    return (unpickle_code, (marshalled,))

56 

def dumps(obj, protocol=None):
    """Serialize *obj* with CodePickler and return the pickled bytes.

    IDLE passes protocol=None, which makes pickle use
    pickle.DEFAULT_PROTOCOL.  Code objects reachable from *obj* are
    encoded through CodePickler's dispatch_table.
    """
    buffer = io.BytesIO()
    CodePickler(buffer, protocol).dump(obj)
    return buffer.getvalue()

64 

65 

class CodePickler(pickle.Pickler):
    """Pickler that can serialize code objects (via pickle_code).

    The dispatch table starts with the types.CodeType entry and then takes
    every entry of copyreg.dispatch_table as it exists when this class is
    created.
    """
    dispatch_table = {types.CodeType: pickle_code}
    dispatch_table.update(copyreg.dispatch_table)

68 

69 

# Largest chunk handed to a single socket send() / recv() call (8 KiB).
BUFSIZE = 8192
# Peer address that RPCClient.accept() is willing to accept.
LOCALHOST = '127.0.0.1'

72 

class RPCServer(socketserver.TCPServer):
    """TCPServer whose "server" end actively connects to its peer.

    The usual bind()/listen()/accept() sequence is replaced: activation
    connect()s the socket to server_address, and get_request() hands back
    that same already-connected socket.
    """

    def __init__(self, addr, handlerclass=None):
        # RPCHandler is used when no handler class is supplied.
        socketserver.TCPServer.__init__(
            self, addr, RPCHandler if handlerclass is None else handlerclass)

    def server_bind(self):
        """Override TCPServer method: the connecting side never binds."""
        return None

    def server_activate(self):
        """Override TCPServer method: connect() rather than listen().

        Because the connection is reversed, self.server_address holds the
        address of the Idle client this process connects to.
        """
        self.socket.connect(self.server_address)

    def get_request(self):
        """Override TCPServer method: return the already connected socket."""
        return self.socket, self.server_address

    def handle_error(self, request, client_address):
        """Override TCPServer method.

        SystemExit is re-raised.  Any other exception is reported on
        sys.__stderr__ (thread, client address, request, traceback) and the
        process is then terminated with os._exit(0).
        """
        try:
            raise
        except SystemExit:
            raise
        except:
            out = sys.__stderr__
            rule = '-' * 40
            thread_name = threading.current_thread().name
            print('\n' + rule, file=out)
            print('Unhandled server exception!', file=out)
            print('Thread: %s' % thread_name, file=out)
            print('Client Address: ', client_address, file=out)
            print('Request: ', repr(request), file=out)
            traceback.print_exc(file=out)
            print('\n*** Unrecoverable, server exiting!', file=out)
            print(rule, file=out)
            os._exit(0)

#----------------- end class RPCServer --------------------

122 

# Default table of locally exported objects, keyed by oid (remoteref() uses
# id(obj)); SocketIO instances use it unless given their own table.
objecttable = {}
# Unbounded queues: SocketIO.localcall() puts "QUEUE" requests on
# request_queue; pollresponse() sends replies it finds on response_queue.
request_queue = queue.Queue(maxsize=0)
response_queue = queue.Queue(maxsize=0)

126 

127 

class SocketIO:
    """Mix-in implementing the RPC message protocol over a connected socket.

    Wire format: every message is a ``(seq, payload)`` tuple serialized with
    dumps() and preceded by its length packed with struct format "<i".
    Request payloads are ``(how, (oid, methodname, args, kwargs))`` with
    *how* either "CALL" or "QUEUE"; response payloads are ``(how, what)``.

    Threading: the thread that ran __init__ is stored as self.sockthread and
    is the only thread that reads the socket.  A request sent from any other
    thread registers a threading.Condition in self.cvars under its sequence
    number; the socket thread later stores the reply in self.responses and
    notifies that condition.

    Subclasses provide ``location`` (a tag printed by debug()) and
    ``debugging`` (unless it is passed to __init__).
    """

    nextseq = 0

    def __init__(self, sock, objtable=None, debugging=None):
        """Wrap the connected socket *sock*.

        objtable maps object ids to the local objects the peer may call; it
        defaults to the module-level objecttable.  debugging, when not None,
        overrides the class-level debugging flag.
        """
        self.sockthread = threading.current_thread()
        if debugging is not None:
            self.debugging = debugging
        self.sock = sock
        if objtable is None:
            objtable = objecttable
        self.objtable = objtable
        self.responses = {}  # seq -> response awaiting a non-socket thread
        self.cvars = {}      # seq -> Condition that thread is waiting on

    def close(self):
        """Close the socket, if any, and forget it."""
        sock = self.sock
        self.sock = None
        if sock is not None:
            sock.close()

    def exithook(self):
        """Exit the process immediately; override for a different action."""
        os._exit(0)

    def debug(self, *args):
        """Write a trace line to sys.__stderr__ when self.debugging is set.

        The line is the location tag, the current thread name and *args,
        each converted with str() and separated by single spaces.
        """
        if not self.debugging:
            return
        parts = [self.location, str(threading.current_thread().name)]
        parts.extend(str(a) for a in args)
        print(" ".join(parts), file=sys.__stderr__)

    def register(self, oid, object):
        """Make *object* callable by the peer under the id *oid*."""
        self.objtable[oid] = object

    def unregister(self, oid):
        """Withdraw *oid* from the object table; unknown ids are ignored."""
        self.objtable.pop(oid, None)

    def localcall(self, seq, request):
        """Execute a request received from the peer and return the response.

        *request* is ``(how, (oid, methodname, args, kwargs))``.  Returns one
        of ("OK", result), ("QUEUED", None), ("ERROR", message),
        ("CALLEXC", exception) or ("EXCEPTION", None).  The pseudo-method
        names "__methods__" and "__attributes__" return dicts naming the
        object's callable and non-callable attributes.  SystemExit,
        KeyboardInterrupt and OSError raised by the call propagate.
        """
        self.debug("localcall:", request)
        try:
            how, (oid, methodname, args, kwargs) = request
        except (TypeError, ValueError):
            # TypeError: request or its body is not iterable;
            # ValueError: wrong number of fields in either level.
            return ("ERROR", "Bad request format")
        if oid not in self.objtable:
            return ("ERROR", "Unknown object id: %r" % (oid,))
        obj = self.objtable[oid]
        if methodname == "__methods__":
            methods = {}
            _getmethods(obj, methods)
            return ("OK", methods)
        if methodname == "__attributes__":
            attributes = {}
            _getattributes(obj, attributes)
            return ("OK", attributes)
        if not hasattr(obj, methodname):
            return ("ERROR", "Unsupported method name: %r" % (methodname,))
        method = getattr(obj, methodname)
        try:
            if how == 'CALL':
                ret = method(*args, **kwargs)
                if isinstance(ret, RemoteObject):
                    # Send a reference the peer can proxy, not a copy.
                    ret = remoteref(ret)
                return ("OK", ret)
            elif how == 'QUEUE':
                # Handed to another thread; its result is sent later from
                # response_queue by pollresponse().
                request_queue.put((seq, (method, args, kwargs)))
                return("QUEUED", None)
            else:
                return ("ERROR", "Unsupported message type: %s" % how)
        except (SystemExit, KeyboardInterrupt, OSError):
            raise
        except Exception as ex:
            return ("CALLEXC", ex)
        except:
            msg = "*** Internal Error: rpc.py:SocketIO.localcall()\n\n"\
                  " Object: %s \n Method: %s \n Args: %s\n"
            print(msg % (oid, method, args), file=sys.__stderr__)
            traceback.print_exc(file=sys.__stderr__)
            return ("EXCEPTION", None)

    def remotecall(self, oid, methodname, args, kwargs):
        """Call methodname(*args, **kwargs) on the peer's object *oid*.

        Blocks until the reply arrives and returns decoderesponse() of it.
        """
        self.debug("remotecall:asynccall: ", oid, methodname)
        seq = self.asynccall(oid, methodname, args, kwargs)
        return self.asyncreturn(seq)

    def remotequeue(self, oid, methodname, args, kwargs):
        """Like remotecall(), but the peer runs the call via request_queue."""
        self.debug("remotequeue:asyncqueue: ", oid, methodname)
        seq = self.asyncqueue(oid, methodname, args, kwargs)
        return self.asyncreturn(seq)

    def asynccall(self, oid, methodname, args, kwargs):
        """Send a "CALL" request without waiting; return its sequence number."""
        return self._sendrequest("CALL", "asynccall",
                                 oid, methodname, args, kwargs)

    def asyncqueue(self, oid, methodname, args, kwargs):
        """Send a "QUEUE" request without waiting; return its sequence number."""
        return self._sendrequest("QUEUE", "asyncqueue",
                                 oid, methodname, args, kwargs)

    def _sendrequest(self, how, tag, oid, methodname, args, kwargs):
        """Shared body of asynccall()/asyncqueue(); *tag* prefixes the debug line."""
        request = (how, (oid, methodname, args, kwargs))
        seq = self.newseq()
        if threading.current_thread() != self.sockthread:
            # Only the socket thread reads replies; other threads wait on a
            # Condition that pollresponse() notifies.
            self.cvars[seq] = threading.Condition()
        self.debug("%s:%d:" % (tag, seq), oid, methodname, args, kwargs)
        self.putmessage((seq, request))
        return seq

    def asyncreturn(self, seq):
        """Wait for the reply to request *seq* and return its decoded value."""
        self.debug("asyncreturn:%d:call getresponse(): " % seq)
        response = self.getresponse(seq, wait=0.05)
        self.debug(("asyncreturn:%d:response: " % seq), response)
        return self.decoderesponse(response)

    def decoderesponse(self, response):
        """Turn a (how, what) response into a return value or an exception.

        "OK" returns *what*; "QUEUED" and "EXCEPTION" return None; "EOF"
        calls decode_interrupthook() and returns None if that returns;
        "ERROR" raises RuntimeError(what); "CALLEXC" raises *what*; any
        other tag raises SystemError(how, what).
        """
        how, what = response
        if how == "OK":
            return what
        if how == "QUEUED":
            return None
        if how == "EXCEPTION":
            self.debug("decoderesponse: EXCEPTION")
            return None
        if how == "EOF":
            self.debug("decoderesponse: EOF")
            self.decode_interrupthook()
            return None
        if how == "ERROR":
            self.debug("decoderesponse: Internal ERROR:", what)
            raise RuntimeError(what)
        if how == "CALLEXC":
            self.debug("decoderesponse: Call Exception:", what)
            raise what
        raise SystemError(how, what)

    def decode_interrupthook(self):
        """Handle an "EOF" response; the default raises EOFError."""
        raise EOFError

    def mainloop(self):
        """Listen on socket until I/O not ready or EOF

        pollresponse() will loop looking for seq number None, which
        never comes, and exit on EOFError.

        """
        try:
            self.getresponse(myseq=None, wait=0.05)
        except EOFError:
            self.debug("mainloop:return")
            return

    def getresponse(self, myseq, wait):
        """Return the response for *myseq*.

        For "OK" responses, RemoteProxy values are replaced by RPCProxy
        objects bound to this connection (see _proxify()).
        """
        response = self._getresponse(myseq, wait)
        if response is not None:
            how, what = response
            if how == "OK":
                response = how, self._proxify(what)
        return response

    def _proxify(self, obj):
        """Replace RemoteProxy objects (also inside lists) with RPCProxy."""
        if isinstance(obj, RemoteProxy):
            return RPCProxy(self, obj.oid)
        if isinstance(obj, list):
            return list(map(self._proxify, obj))
        # XXX Check for other types -- not currently needed
        return obj

    def _getresponse(self, myseq, wait):
        """Obtain the raw response for *myseq*.

        The socket thread polls the socket itself; any other thread sleeps
        on the Condition registered for *myseq* until pollresponse() stores
        the response in self.responses.
        """
        self.debug("_getresponse:myseq:", myseq)
        if threading.current_thread() is self.sockthread:
            # this thread does all reading of requests or responses
            while True:
                response = self.pollresponse(myseq, wait)
                if response is not None:
                    return response
        else:
            # wait for notification from socket handling thread
            cvar = self.cvars[myseq]
            with cvar:
                while myseq not in self.responses:
                    cvar.wait()
                response = self.responses[myseq]
                self.debug("_getresponse:%s: thread woke up: response: %s" %
                           (myseq, response))
                del self.responses[myseq]
                del self.cvars[myseq]
            return response

    def newseq(self):
        """Return the next sequence number (steps of 2 from the class start)."""
        self.nextseq = seq = self.nextseq + 2
        return seq

    def putmessage(self, message):
        """Pickle *message*, prefix its length and send it all on the socket.

        Raises OSError("socket no longer exists") if self.sock is unusable
        (for example after close()).
        """
        self.debug("putmessage:%d:" % message[0])
        try:
            s = dumps(message)
        except pickle.PicklingError:
            print("Cannot pickle:", repr(message), file=sys.__stderr__)
            raise
        s = struct.pack("<i", len(s)) + s
        while s:
            try:
                # Block until writable; the selected lists are not needed.
                select.select([], [self.sock], [])
                n = self.sock.send(s[:BUFSIZE])
            except (AttributeError, TypeError):
                raise OSError("socket no longer exists")
            s = s[n:]

    # Receive-side framing state shared by pollpacket/_stage0/_stage1.
    buff = b''
    bufneed = 4
    bufstate = 0 # meaning: 0 => reading count; 1 => reading data

    def pollpacket(self, wait):
        """Return the next complete packet, or None if it is not yet available.

        Performs at most one recv(), after waiting up to *wait* seconds for
        the socket to become readable.  Raises EOFError when the peer has
        closed the connection or recv() fails.
        """
        self._stage0()
        if len(self.buff) < self.bufneed:
            readable, _, _ = select.select([self.sock.fileno()], [], [], wait)
            if not readable:
                return None
            try:
                s = self.sock.recv(BUFSIZE)
            except OSError:
                raise EOFError
            if len(s) == 0:
                raise EOFError
            self.buff += s
            self._stage0()
        return self._stage1()

    def _stage0(self):
        """Consume the 4-byte length prefix once it has been buffered."""
        if self.bufstate == 0 and len(self.buff) >= 4:
            s = self.buff[:4]
            self.buff = self.buff[4:]
            self.bufneed = struct.unpack("<i", s)[0]
            self.bufstate = 1

    def _stage1(self):
        """Return the packet body once fully buffered (else None) and reset."""
        if self.bufstate == 1 and len(self.buff) >= self.bufneed:
            packet = self.buff[:self.bufneed]
            self.buff = self.buff[self.bufneed:]
            self.bufneed = 4
            self.bufstate = 0
            return packet

    def pollmessage(self, wait):
        """Return the next unpickled message, or None if no packet is ready."""
        packet = self.pollpacket(wait)
        if packet is None:
            return None
        try:
            # NOTE: pickle.loads can execute code chosen by the peer; the
            # link is assumed trusted (RPCClient.accept rejects non-LOCALHOST
            # peers).
            message = pickle.loads(packet)
        except pickle.UnpicklingError:
            print("-----------------------", file=sys.__stderr__)
            print("cannot unpickle packet:", repr(packet), file=sys.__stderr__)
            traceback.print_stack(file=sys.__stderr__)
            print("-----------------------", file=sys.__stderr__)
            raise
        return message

    def pollresponse(self, myseq, wait):
        """Handle messages received on the socket.

        Some messages received may be asynchronous 'call' or 'queue' requests,
        and some may be responses for other threads.

        'call' requests are passed to self.localcall() with the expectation of
        immediate execution, during which time the socket is not serviced.

        'queue' requests are used for tasks (which may block or hang) to be
        processed in a different thread.  These requests are fed into
        request_queue by self.localcall().  Responses to queued requests are
        taken from response_queue and sent across the link with the associated
        sequence numbers.  Messages in the queues are (sequence_number,
        request/response) tuples and code using this module removing messages
        from the request_queue is responsible for returning the correct
        sequence number in the response_queue.

        pollresponse() will loop until a response message with the myseq
        sequence number is received, and will save other responses in
        self.responses and notify the owning thread.

        """
        while True:
            # send queued response if there is one available
            try:
                qmsg = response_queue.get(0)
            except queue.Empty:
                pass
            else:
                seq, response = qmsg
                message = (seq, ('OK', response))
                self.putmessage(message)
            # poll for message on link
            try:
                message = self.pollmessage(wait)
                if message is None:  # socket not ready
                    return None
            except EOFError:
                self.handle_EOF()
                return None
            except AttributeError:
                return None
            seq, resq = message
            how = resq[0]
            self.debug("pollresponse:%d:myseq:%s" % (seq, myseq))
            # process or queue a request
            if how in ("CALL", "QUEUE"):
                self.debug("pollresponse:%d:localcall:call:" % seq)
                response = self.localcall(seq, resq)
                self.debug("pollresponse:%d:localcall:response:%s"
                           % (seq, response))
                if how == "CALL":
                    self.putmessage((seq, response))
                elif how == "QUEUE":
                    # don't acknowledge the 'queue' request!
                    pass
                continue
            # return if completed message transaction
            elif seq == myseq:
                return resq
            # must be a response for a different thread:
            else:
                cv = self.cvars.get(seq, None)
                # response involving unknown sequence number is discarded,
                # probably intended for prior incarnation of server
                if cv is not None:
                    with cv:
                        self.responses[seq] = resq
                        cv.notify()
                continue

    def handle_EOF(self):
        "action taken upon link being closed by peer"
        self.EOFhook()
        self.debug("handle_EOF")
        # Iterate over a snapshot: each woken thread deletes its own entry
        # from self.cvars (in _getresponse) while this loop may still run.
        for key, cv in list(self.cvars.items()):
            with cv:
                self.responses[key] = ('EOF', None)
                cv.notify()
        # call our (possibly overridden) exit function
        self.exithook()

    def EOFhook(self):
        "Classes using rpc client/server can override to augment EOF action"
        pass

#----------------- end class SocketIO --------------------

488 

class RemoteObject:
    """Token mix-in: SocketIO.localcall() returns instances by reference.

    A call result that is a RemoteObject is passed through remoteref()
    instead of being pickled by value.
    """

492 

493 

def remoteref(obj):
    """Export *obj* in objecttable under id(obj) and return a RemoteProxy for it."""
    key = id(obj)
    objecttable[key] = obj
    return RemoteProxy(key)

498 

499 

class RemoteProxy:
    """Picklable handle that carries only an object id (*oid*).

    The receiving SocketIO._proxify() turns it into an RPCProxy.
    """

    def __init__(self, oid):
        # Identifier of the exported object in the sender's object table.
        self.oid = oid

504 

505 

class RPCHandler(socketserver.BaseRequestHandler, SocketIO):
    """RPCServer's request handler: speaks the SocketIO protocol on the socket."""

    debugging = False
    location = "#S"  # Server

    def __init__(self, sock, addr, svr):
        # Publish this handler on the server before the base-class
        # constructor runs, because socketserver's BaseRequestHandler
        # calls handle() (and thus mainloop()) from its __init__.
        svr.current_handler = self
        SocketIO.__init__(self, sock)
        socketserver.BaseRequestHandler.__init__(self, sock, addr, svr)

    def handle(self):
        """handle() method required by socketserver: run mainloop()."""
        self.mainloop()

    def get_remote_proxy(self, oid):
        """Return an RPCProxy for the peer's object *oid*."""
        return RPCProxy(self, oid)

522 

523 

class RPCClient(SocketIO):
    """Listening end of the reversed connection (see the module docstring).

    The execution server (RPCServer) connects to this socket; once accept()
    succeeds the instance behaves as a SocketIO on the accepted socket.
    """

    debugging = False
    location = "#C"  # Client

    nextseq = 1  # Requests coming from the client are odd numbered

    def __init__(self, address, family=socket.AF_INET, type=socket.SOCK_STREAM):
        """Create a socket bound to *address* that listens for one peer.

        The socket is closed again if binding or listening fails, so a
        failed constructor does not leak a file descriptor.
        """
        self.listening_sock = socket.socket(family, type)
        try:
            self.listening_sock.bind(address)
            self.listening_sock.listen(1)
        except BaseException:
            self.listening_sock.close()
            raise

    def accept(self):
        """Accept one connection and initialize SocketIO on it.

        Only peers whose host is LOCALHOST are accepted.  For any other
        peer the accepted socket is closed and OSError is raised.
        """
        working_sock, address = self.listening_sock.accept()
        if self.debugging:
            print("****** Connection request from ", address, file=sys.__stderr__)
        if address[0] == LOCALHOST:
            SocketIO.__init__(self, working_sock)
        else:
            print("** Invalid host: ", address, file=sys.__stderr__)
            # Release the rejected connection now rather than leaving it
            # to garbage collection.
            working_sock.close()
            raise OSError

    def get_remote_proxy(self, oid):
        """Return an RPCProxy for the peer's object *oid*."""
        return RPCProxy(self, oid)

548 

549 

class RPCProxy:
    """Local stand-in for an object living on the other end of the link.

    Attribute lookups are resolved lazily.  The remote object's callable
    names ("__methods__") are fetched first and cached for the lifetime of
    the proxy.  Its non-callable attribute names ("__attributes__") are
    fetched and cached only when needed.  Methods come back as MethodProxy
    callables; data attributes are read with a remote __getattribute__ call
    on every access.
    """

    __methods = None     # cached {name: 1} of remote callables
    __attributes = None  # cached {name: 1} of remote non-callables

    def __init__(self, sockio, oid):
        self.sockio = sockio
        self.oid = oid

    def __getattr__(self, name):
        if self.__methods is None:
            self.__getmethods()
        if self.__methods.get(name):
            return MethodProxy(self.sockio, self.oid, name)
        if self.__attributes is None:
            self.__getattributes()
        if name not in self.__attributes:
            raise AttributeError(name)
        return self.sockio.remotecall(self.oid, '__getattribute__',
                                      (name,), {})

    def __getmethods(self):
        """Fetch and cache the names of the remote object's callables."""
        self.__methods = self.sockio.remotecall(self.oid,
                                                "__methods__", (), {})

    def __getattributes(self):
        """Fetch and cache the names of the remote object's data attributes."""
        self.__attributes = self.sockio.remotecall(self.oid,
                                                   "__attributes__", (), {})

580 

581def _getmethods(obj, methods): 

582 # Helper to get a list of methods from an object 

583 # Adds names to dictionary argument 'methods' 

584 for name in dir(obj): 

585 attr = getattr(obj, name) 

586 if callable(attr): 

587 methods[name] = 1 

588 if isinstance(obj, type): 

589 for super in obj.__bases__: 

590 _getmethods(super, methods) 

591 

592def _getattributes(obj, attributes): 

593 for name in dir(obj): 

594 attr = getattr(obj, name) 

595 if not callable(attr): 

596 attributes[name] = 1 

597 

598 

class MethodProxy:
    """Callable that forwards each call to method *name* of remote object *oid*."""

    def __init__(self, sockio, oid, name):
        self.sockio, self.oid, self.name = sockio, oid, name

    def __call__(self, /, *args, **kwargs):
        # Positional-only self lets callers pass a keyword named "self".
        return self.sockio.remotecall(self.oid, self.name, args, kwargs)

609 

610 

611# XXX KBK 09Sep03 We need a proper unit test for this module. Previously 

612# existing test code was removed at Rev 1.27 (r34098). 

613 

def displayhook(value):
    """Override standard display hook to use non-locale encoding.

    Writes repr(value) and a newline to sys.stdout, then stores *value* in
    builtins._.  If sys.stdout raises UnicodeEncodeError, the text is
    rewritten as ASCII with backslash escapes and written again.  None is
    not displayed and leaves builtins._ unchanged.
    """
    if value is None:
        return
    # Set '_' to None to avoid recursion
    builtins._ = None
    text = repr(value)
    try:
        sys.stdout.write(text)
    except UnicodeEncodeError:
        # Fall back to ASCII, escaping every character it cannot represent.
        encoding = 'ascii'
        escaped = text.encode(encoding, 'backslashreplace')
        text = escaped.decode(encoding, 'strict')
        sys.stdout.write(text)
    sys.stdout.write("\n")
    builtins._ = value

631 

632 

if __name__ == '__main__':
    # Running this module directly executes IDLE's rpc unit tests.
    from unittest import main
    main('idlelib.idle_test.test_rpc', verbosity=2)