|
@@ -1,8 +1,9 @@
|
|
|
from argparse import Namespace
|
|
|
from importlib import import_module
|
|
|
+import io
|
|
|
import re
|
|
|
import sys
|
|
|
-from typing import Any, Dict, List, Tuple, TYPE_CHECKING
|
|
|
+from typing import Any, BinaryIO, Dict, List, Tuple, TYPE_CHECKING
|
|
|
|
|
|
import numpy
|
|
|
import yaml
|
|
@@ -14,7 +15,7 @@ if TYPE_CHECKING:
|
|
|
|
|
|
|
|
|
class Reader(object):
|
|
|
- def read(self, file):
|
|
|
+ def read(self, fileobj: BinaryIO):
|
|
|
raise NotImplementedError
|
|
|
|
|
|
def get_name(self):
|
|
@@ -61,7 +62,7 @@ class Reader(object):
|
|
|
|
|
|
|
|
|
class YamlReader(Reader):
|
|
|
- def read(self, file: str):
|
|
|
+ def read(self, fileobj: BinaryIO):
|
|
|
yaml.reader.Reader.NON_PRINTABLE = re.compile(r"(?!x)x")
|
|
|
try:
|
|
|
loader = yaml.CLoader
|
|
@@ -71,12 +72,9 @@ class YamlReader(Reader):
|
|
|
)
|
|
|
loader = yaml.Loader
|
|
|
try:
|
|
|
- if file != "-":
|
|
|
- with open(file) as fin:
|
|
|
- data = yaml.load(fin, Loader=loader)
|
|
|
- else:
|
|
|
- data = yaml.load(sys.stdin, Loader=loader)
|
|
|
- except (UnicodeEncodeError, yaml.reader.ReaderError) as e:
|
|
|
+ wrapper = io.TextIOWrapper(fileobj, encoding="utf-8")
|
|
|
+ data = yaml.load(wrapper, Loader=loader)
|
|
|
+ except (UnicodeEncodeError, UnicodeDecodeError, yaml.reader.ReaderError) as e:
|
|
|
print(
|
|
|
"\nInvalid unicode in the input: %s\nPlease filter it through "
|
|
|
"fix_yaml_unicode.py" % e
|
|
@@ -217,7 +215,7 @@ class YamlReader(Reader):
|
|
|
|
|
|
|
|
|
class ProtobufReader(Reader):
|
|
|
- def read(self, file: str) -> None:
|
|
|
+ def read(self, fileobj: BinaryIO) -> None:
|
|
|
try:
|
|
|
from labours.pb_pb2 import AnalysisResults
|
|
|
except ImportError as e:
|
|
@@ -227,14 +225,10 @@ class ProtobufReader(Reader):
|
|
|
)
|
|
|
raise e from None
|
|
|
self.data = AnalysisResults()
|
|
|
- if file != "-":
|
|
|
- with open(file, "rb") as fin:
|
|
|
- bytes = fin.read()
|
|
|
- else:
|
|
|
- bytes = sys.stdin.buffer.read()
|
|
|
- if not bytes:
|
|
|
+ all_bytes = fileobj.read()
|
|
|
+ if not all_bytes:
|
|
|
raise ValueError("empty input")
|
|
|
- self.data.ParseFromString(bytes)
|
|
|
+ self.data.ParseFromString(all_bytes)
|
|
|
self.contents = {}
|
|
|
for key, val in self.data.contents.items():
|
|
|
try:
|
|
@@ -370,23 +364,81 @@ PB_MESSAGES = {
|
|
|
}
|
|
|
|
|
|
|
|
|
+def chain_streams(streams, buffer_size=io.DEFAULT_BUFFER_SIZE):
|
|
|
+ """
|
|
|
+ Chain an iterable of streams together into a single buffered stream.
|
|
|
+ Source: https://stackoverflow.com/a/50770511
|
|
|
+
|
|
|
+ Usage:
|
|
|
+ f = chain_streams(open(f, "rb") for f in filenames)
|
|
|
+ f.read()
|
|
|
+ """
|
|
|
+
|
|
|
+ class ChainStream(io.RawIOBase):
|
|
|
+ def __init__(self):
|
|
|
+ self.leftover = b""
|
|
|
+ self.stream_iter = iter(streams)
|
|
|
+ try:
|
|
|
+ self.stream = next(self.stream_iter)
|
|
|
+ except StopIteration:
|
|
|
+ self.stream = None
|
|
|
+
|
|
|
+ def readable(self):
|
|
|
+ return True
|
|
|
+
|
|
|
+ def _read_next_chunk(self, max_length):
|
|
|
+ # Return 0 or more bytes from the current stream, first returning all
|
|
|
+ # leftover bytes. If the stream is closed returns b''
|
|
|
+ if self.leftover:
|
|
|
+ return self.leftover
|
|
|
+ elif self.stream is not None:
|
|
|
+ return self.stream.read(max_length)
|
|
|
+ else:
|
|
|
+ return b""
|
|
|
+
|
|
|
+ def readinto(self, b):
|
|
|
+ buffer_length = len(b)
|
|
|
+ chunk = self._read_next_chunk(buffer_length)
|
|
|
+ while len(chunk) == 0:
|
|
|
+ # move to next stream
|
|
|
+ if self.stream is not None:
|
|
|
+ self.stream.close()
|
|
|
+ try:
|
|
|
+ self.stream = next(self.stream_iter)
|
|
|
+ chunk = self._read_next_chunk(buffer_length)
|
|
|
+ except StopIteration:
|
|
|
+ # No more streams to chain together
|
|
|
+ self.stream = None
|
|
|
+ return 0 # indicate EOF
|
|
|
+ output, self.leftover = chunk[:buffer_length], chunk[buffer_length:]
|
|
|
+ b[:len(output)] = output
|
|
|
+ return len(output)
|
|
|
+
|
|
|
+ return io.BufferedReader(ChainStream(), buffer_size=buffer_size)
|
|
|
+
|
|
|
+
|
|
|
def read_input(args: Namespace) -> ProtobufReader:
|
|
|
sys.stdout.write("Reading the input... ")
|
|
|
sys.stdout.flush()
|
|
|
if args.input != "-":
|
|
|
+ stream = open(args.input, "rb")
|
|
|
+ else:
|
|
|
+ stream = sys.stdin.buffer
|
|
|
+ try:
|
|
|
if args.input_format == "auto":
|
|
|
+ buffer = stream.read(1 << 16)
|
|
|
try:
|
|
|
- args.input_format = args.input.rsplit(".", 1)[1]
|
|
|
- except IndexError:
|
|
|
- try:
|
|
|
- with open(args.input) as f:
|
|
|
- f.read(1 << 16)
|
|
|
- args.input_format = "yaml"
|
|
|
- except UnicodeDecodeError:
|
|
|
- args.input_format = "pb"
|
|
|
- elif args.input_format == "auto":
|
|
|
- args.input_format = "yaml"
|
|
|
- reader = READERS[args.input_format]()
|
|
|
- reader.read(args.input)
|
|
|
+ buffer.decode("utf-8")
|
|
|
+ args.input_format = "yaml"
|
|
|
+ except UnicodeDecodeError:
|
|
|
+ args.input_format = "pb"
|
|
|
+ ins = chain_streams((io.BytesIO(buffer), stream), len(buffer))
|
|
|
+ else:
|
|
|
+ ins = stream
|
|
|
+ reader = READERS[args.input_format]()
|
|
|
+ reader.read(ins)
|
|
|
+ finally:
|
|
|
+ if args.input != "-":
|
|
|
+ stream.close()
|
|
|
print("done")
|
|
|
return reader
|