diff --git a/tcd.py b/tcd.py
index 50d96f0..7134e73 100644
--- a/tcd.py
+++ b/tcd.py
@@ -28,7 +28,7 @@ def converttcdfilename(filename):
 
 
 def transcribe(filename: str):
-    model = whisper.load_model("small").cuda()
+    model = whisper.load_model("small")  # .cuda()
     result = model.transcribe(filename, language="en", temperature=0.0)
     assert result["language"] == "en"
     return result["text"].strip()
@@ -58,6 +58,18 @@ def writelast(filename: str):
         lf.write("\n")
 
 
+# https://stackoverflow.com/questions/5574702/how-do-i-print-to-stderr-in-python
+def maybe_eprint(*args, **kwargs):
+    print(*args, file=sys.stderr, **kwargs)
+
+
+def maybe_not_eprint(*args, **kwargs):
+    pass
+
+
+eprint = maybe_not_eprint
+
+
 def main():
     parser = argparse.ArgumentParser(
         description="Process Elf's Thought Capture Device"
@@ -83,23 +95,47 @@ def main():
         help="Ignore cached last-read value, process entire TCD.",
     )
 
-    if not os.path.isdir(TCD_PATH):
+    parser.add_argument(
+        "-p", "--path", nargs=1, type=str, help="Path to TCD recordings"
+    )
+
+    parser.add_argument(
+        "-v", "--verbose", action="store_true", help="Show plan and progress."
+    )
+
+    args = vars(parser.parse_args())
+
+    tcd_path = TCD_PATH
+
+    eprint = maybe_not_eprint
+    if args and "verbose" in args and args["verbose"]:
+        eprint = maybe_eprint
+
+    if (
+        args
+        and "path" in args
+        and args["path"] is not None
+        and len(args["path"]) > 0
+    ):
+        tcd_path = args["path"][0]
+        eprint("Setting repository path to: {}".format(tcd_path))
+
+    if not os.path.isdir(tcd_path):
         print("Could not identify path to TCD repository. Is it mounted?")
         sys.exit(1)
 
-    files = [p for p in os.listdir(TCD_PATH) if tcd_record.match(p)]
+    files = [p for p in os.listdir(tcd_path) if tcd_record.match(p)]
     if len(files) < 1:
         print("No files in TCD repository to process.")
         sys.exit(0)
 
-    args = vars(parser.parse_args())
-
     if (
         args
         and "file" in args
         and args["file"] is not None
         and len(args["file"]) > 0
     ):
+        eprint("Transcribing {}".format(args["file"][0]))
         transcribe(args["file"][0])
         sys.exit(0)
 
@@ -122,18 +158,21 @@ def main():
                 )
             )
             sys.exit(1)
+        eprint("Transcribing {}".format(args["file"][pos]))
         transcribe(args["file"][pos])
         sys.exit(0)
 
     if args and "all" in args and args["all"]:
+        eprint("Transcribing all files in repository {}".format(tcd_path))
         for index, f in enumerate(files):
-            transcribe_with_timestamp(os.path.join(TCD_PATH, f))
+            transcribe_with_timestamp(os.path.join(tcd_path, f))
             if index != len(files):
                 print("\n\n")
         sys.exit(0)
 
     lastfile = getlastfile()
     files_to_transcribe = [f for f in files]
+    files_to_transcribe.sort()
     if lastfile:
         try:
             lastindex = files_to_transcribe.index(lastfile)
@@ -145,12 +184,13 @@ def main():
         print("NOTICE: No new entries found to transcribe.")
         sys.exit(0)
 
+    eprint("Transcribing: {}".format(", ".join(files_to_transcribe)))
     for index, f in enumerate(files_to_transcribe):
-        transcribe_with_timestamp(os.path.join(TCD_PATH, f))
+        transcribe_with_timestamp(os.path.join(tcd_path, f))
         if index != len(files):
             print("\n\n")
 
-    writelast(files[-1])
+    writelast(files_to_transcribe[-1])
     sys.exit(0)
 
 