diff options
author | Konstantin Ryabitsev <konstantin@linuxfoundation.org> | 2024-01-23 13:28:04 -0500 |
---|---|---|
committer | Konstantin Ryabitsev <konstantin@linuxfoundation.org> | 2024-01-23 13:28:04 -0500 |
commit | c3bdb04772565ccd2a20ddd990f5f1c2b4dc77ed (patch) | |
tree | 712e2672a4dcf8cb30193eb0d6aaa51f9cffcbf1 | |
parent | f0463198e476b04df470f2170d5d3f2fce5eefa0 (diff) | |
download | b4-c3bdb04772565ccd2a20ddd990f5f1c2b4dc77ed.tar.gz |
Add typing hints where still missing
We still had quite a few places where we had no typing hints, so add it
in most places where that was relevant.
Signed-off-by: Konstantin Ryabitsev <konstantin@linuxfoundation.org>
-rw-r--r-- | b4/__init__.py | 86 | ||||
-rw-r--r-- | b4/diff.py | 13 | ||||
-rw-r--r-- | b4/pr.py | 18 | ||||
-rw-r--r-- | b4/ty.py | 45 |
4 files changed, 92 insertions, 70 deletions
diff --git a/b4/__init__.py b/b4/__init__.py index 5b692a1..21441eb 100644 --- a/b4/__init__.py +++ b/b4/__init__.py @@ -34,7 +34,7 @@ import requests from pathlib import Path from contextlib import contextmanager -from typing import Optional, Tuple, Set, List, BinaryIO, Union, Sequence, Literal +from typing import Optional, Tuple, Set, List, BinaryIO, Union, Sequence, Literal, Iterator, Dict from email import charset @@ -162,11 +162,11 @@ MAILMAP_INFO = dict() class LoreMailbox: - msgid_map: dict - series: dict - covers: dict - followups: list - unknowns: list + msgid_map: Dict[str, 'LoreMessage'] + series: Dict[int, 'LoreSeries'] + covers: Dict[int, 'LoreMessage'] + followups: List['LoreMessage'] + unknowns: List['LoreMessage'] def __init__(self): self.msgid_map = dict() @@ -878,7 +878,7 @@ class LoreSeries: raise IndexError - def make_fake_am_range(self, gitdir): + def make_fake_am_range(self, gitdir: Optional[str]) -> Tuple[Optional[str], Optional[str]]: start_commit = end_commit = None # Use the msgid of the first non-None patch in the series msgid = None @@ -1080,14 +1080,14 @@ class LoreTrailer: return False - def __eq__(self, other): + def __eq__(self, other: 'LoreTrailer') -> bool: # We never compare extinfo, we just tack it if we find a match return self.lname == other.lname and self.value.lower() == other.value.lower() - def __hash__(self): + def __hash__(self) -> hash: return hash(f'{self.lname}: {self.value}') - def __repr__(self): + def __repr__(self) -> str: out = list() out.append(' type: %s' % self.type) out.append(' name: %s' % self.name) @@ -1706,7 +1706,7 @@ class LoreMessage: return mbody, mcharset @staticmethod - def clean_header(hdrval): + def clean_header(hdrval: Optional[str]) -> str: if hdrval is None: return '' @@ -1825,7 +1825,7 @@ class LoreMessage: return hdata @staticmethod - def get_clean_msgid(msg: email.message.Message, header='Message-Id') -> str: + def get_clean_msgid(msg: email.message.Message, header: str = 'Message-Id') -> str: msgid = None raw = msg.get(header) if raw: @@ -2253,7 +2253,7 @@ class LoreMessage: self.body = LoreMessage.rebuild_message(bheaders, message, fixtrailers, basement, signature) def get_am_subject(self, indicate_reroll: bool = True, use_subject: Optional[str] = None, - show_ci_status: bool = True): + show_ci_status: bool = True) -> str: # Return a clean patch subject parts = ['PATCH'] if self.lsubject.rfc: @@ -2281,7 +2281,7 @@ class LoreMessage: def get_am_message(self, add_trailers: bool = True, addmysob: bool = False, extras: Optional[List['LoreTrailer']] = None, copyccs: bool = False, - allowbadchars: bool = False): + allowbadchars: bool = False) -> email.message.EmailMessage: # Look through the body to make sure there aren't any suspicious unicode control flow chars # First, encode into ascii and compare for a quickie utf8 presence test if not allowbadchars and self.body.encode('ascii', errors='replace') != self.body.encode(): @@ -2335,10 +2335,21 @@ class LoreMessage: class LoreSubject: + full_subject: str + subject: str + reply: bool + resend: bool + patch: bool + rfc: bool + revision: int + counter: int + expected: int + revision_inferred: bool + counters_inferred: bool + prefixes: List[str] + def __init__(self, subject): # Subject-based info - self.full_subject = None - self.subject = None self.reply = False self.resend = False self.patch = False @@ -2414,7 +2425,7 @@ class LoreSubject: return ret - def get_rebuilt_subject(self, eprefixes: Optional[List[str]] = None): + def get_rebuilt_subject(self, eprefixes: Optional[List[str]] = None) -> str: _pfx = self.get_extra_prefixes() if eprefixes: for _epfx in eprefixes: @@ -2430,13 +2441,13 @@ class LoreSubject: else: return f'[PATCH] {self.subject}' - def get_slug(self, sep='_', with_counter: bool = True): + def get_slug(self, sep='_', with_counter: bool = True) -> str: unsafe = self.subject if with_counter: unsafe = '%04d%s%s' % (self.counter, sep, unsafe) return re.sub(r'\W+', sep, unsafe).strip(sep).lower() - def __repr__(self): + def __repr__(self) -> str: out = list() out.append(' full_subject: %s' % self.full_subject) out.append(' subject: %s' % self.subject) @@ -2672,7 +2683,7 @@ def git_get_repo_status(gitdir: Optional[str] = None, untracked: bool = False) - @contextmanager -def git_temp_worktree(gitdir=None, commitish=None): +def git_temp_worktree(gitdir: Optional[str] = None, commitish: Optional[str] = None) -> Optional[Iterator[Path]]: """Context manager that creates a temporary work tree and chdirs into it. The worktree is deleted when the contex manager is closed. Taken from gj_tools.""" dfn = None @@ -2690,7 +2701,7 @@ def git_temp_worktree(gitdir=None, commitish=None): @contextmanager -def git_temp_clone(gitdir=None): +def git_temp_clone(gitdir: Optional[str] = None) -> Optional[Iterator[Path]]: """Context manager that creates a temporary shared clone.""" if gitdir is None: topdir = git_get_toplevel() @@ -2708,7 +2719,7 @@ def git_temp_clone(gitdir=None): @contextmanager -def in_directory(dirname): +def in_directory(dirname: str) -> Iterator[bool]: """Context manager that chdirs into a directory and restores the original directory when closed. Taken from gj_tools.""" cdir = os.getcwd() @@ -2839,7 +2850,7 @@ def get_cache_dir(appname: str = 'b4') -> str: return cachedir -def get_cache_file(identifier: str, suffix: Optional[str] = None): +def get_cache_file(identifier: str, suffix: Optional[str] = None) -> str: cachedir = get_cache_dir() cachefile = hashlib.sha1(identifier.encode()).hexdigest() if suffix: @@ -2875,7 +2886,7 @@ def save_cache(contents: str, identifier: str, suffix: Optional[str] = None, mod logger.debug('Could not write cache %s for %s', fullpath, identifier) -def get_user_config(): +def get_user_config() -> dict: global USER_CONFIG if USER_CONFIG is None: USER_CONFIG = get_config_from_git(r'user\..*') @@ -2887,7 +2898,7 @@ def get_user_config(): return USER_CONFIG -def get_requests_session(): +def get_requests_session() -> requests.Session: global REQSESSION if REQSESSION is None: REQSESSION = requests.session() @@ -2942,7 +2953,8 @@ def get_msgid(cmdargs: argparse.Namespace) -> Optional[str]: return msgid -def get_strict_thread(msgs, msgid, noparent=False): +def get_strict_thread(msgs: Union[List[email.message.Message], mailbox.Mailbox, mailbox.Maildir], + msgid: str, noparent: bool = False) -> Optional[List[email.message.Message]]: want = {msgid} ignore = set() got = set() @@ -3109,7 +3121,7 @@ def split_and_dedupe_pi_results(t_mbox: bytes, cachedir: Optional[str] = None) - return msgs -def get_pi_thread_by_url(t_mbx_url: str, nocache: bool = False): +def get_pi_thread_by_url(t_mbx_url: str, nocache: bool = False) -> Optional[List[email.message.Message]]: msgs = list() cachedir = get_cache_file(t_mbx_url, 'pi.msgs') if os.path.exists(cachedir) and not nocache: @@ -3138,7 +3150,7 @@ def get_pi_thread_by_url(t_mbx_url: str, nocache: bool = False): def get_pi_thread_by_msgid(msgid: str, nocache: bool = False, - onlymsgids: Optional[set] = None) -> Optional[list]: + onlymsgids: Optional[set] = None) -> Optional[List[email.message.Message]]: qmsgid = urllib.parse.quote_plus(msgid, safe='@') config = get_main_config() loc = urllib.parse.urlparse(config['midmask']) @@ -3281,7 +3293,7 @@ def git_commit_exists(gitdir: Optional[str], commit_id: str) -> bool: return ecode == 0 -def git_branch_exists(gitdir: Optional[str], branch_name: str): +def git_branch_exists(gitdir: Optional[str], branch_name: str) -> bool: gitargs = ['rev-parse', branch_name] ecode, out = git_run_command(gitdir, gitargs) return ecode == 0 @@ -3317,7 +3329,7 @@ def git_get_toplevel(path: Optional[str] = None) -> Optional[str]: return topdir -def format_addrs(pairs, clean=True): +def format_addrs(pairs: List[Tuple[str, str]], clean: bool = True) -> str: addrs = list() for pair in pairs: if pair[0] == pair[1]: @@ -3335,7 +3347,7 @@ def format_addrs(pairs, clean=True): return ', '.join(addrs) -def make_quote(body, maxlines=5): +def make_quote(body: str, maxlines: int = 5) -> str: headers, message, trailers, basement, signature = LoreMessage.get_body_parts(body) if not len(message): # Sometimes there is no message, just trailers @@ -3355,7 +3367,7 @@ def make_quote(body, maxlines=5): return '\n'.join(quotelines) -def parse_int_range(intrange, upper=None): +def parse_int_range(intrange: str, upper: Optional[int] = None) -> Iterator[int]: # Remove all whitespace intrange = re.sub(r'\s', '', intrange) for n in intrange.split(','): @@ -3401,7 +3413,7 @@ def check_gpg_status(status: str) -> Tuple[bool, bool, bool, Optional[str], Opti return good, valid, trusted, keyid, signtime -def get_gpg_uids(keyid: str) -> list: +def get_gpg_uids(keyid: str) -> List[str]: gpgargs = ['--with-colons', '--list-keys', keyid] ecode, out, err = gpg_run_command(gpgargs) if ecode > 0: @@ -3421,7 +3433,7 @@ def get_gpg_uids(keyid: str) -> list: return uids -def save_git_am_mbox(msgs: List[email.message.Message], dest: BinaryIO): +def save_git_am_mbox(msgs: List[email.message.Message], dest: BinaryIO) -> None: # Git-am has its own understanding of what "mbox" format is that differs from Python's # mboxo implementation. Specifically, it never escapes the ">From " lines found in bodies # unless invoked with --patch-format=mboxrd (this is wrong, because ">From " escapes are also @@ -3432,14 +3444,14 @@ def save_git_am_mbox(msgs: List[email.message.Message], dest: BinaryIO): dest.write(LoreMessage.get_msg_as_bytes(msg, headers='decode')) -def save_mboxrd_mbox(msgs: List[email.message.Message], dest: BinaryIO, mangle_from: bool = False): +def save_mboxrd_mbox(msgs: List[email.message.Message], dest: BinaryIO, mangle_from: bool = False) -> None: gen = email.generator.BytesGenerator(dest, mangle_from_=mangle_from, policy=emlpolicy) for msg in msgs: dest.write(b'From mboxrd@z Thu Jan 1 00:00:00 1970\n') gen.flatten(msg) -def save_maildir(msgs: list, dest): +def save_maildir(msgs: list, dest) -> None: d_new = os.path.join(dest, 'new') pathlib.Path(d_new).mkdir(parents=True) d_cur = os.path.join(dest, 'cur') @@ -3485,7 +3497,7 @@ def get_mailinfo(bmsg: bytes, scissors: bool = False) -> Tuple[dict, bytes, byte return i, m, p -def read_template(tptfile): +def read_template(tptfile: str) -> str: # bubbles up FileNotFound tpt = '' if tptfile.find('~') >= 0: @@ -13,11 +13,14 @@ import mailbox import email import shutil import pathlib +import argparse + +from typing import Tuple, Optional, List logger = b4.logger -def diff_same_thread_series(cmdargs): +def diff_same_thread_series(cmdargs: argparse.Namespace) -> Tuple[Optional[b4.LoreSeries], Optional[b4.LoreSeries]]: msgid = b4.get_msgid(cmdargs) wantvers = cmdargs.wantvers if wantvers and len(wantvers) > 2: @@ -41,7 +44,7 @@ def diff_same_thread_series(cmdargs): msgs = b4.get_pi_thread_by_msgid(msgid, nocache=cmdargs.nocache) if not msgs: logger.critical('Unable to retrieve thread: %s', msgid) - return + return None, None msgs = b4.mbox.get_extra_series(msgs, direction=-1, wantvers=wantvers) if os.path.exists(cachedir): shutil.rmtree(cachedir) @@ -89,12 +92,12 @@ def diff_same_thread_series(cmdargs): return lmbx.series[lower], lmbx.series[upper] -def diff_mboxes(cmdargs): +def diff_mboxes(cmdargs: argparse.Namespace) -> Optional[List[b4.LoreSeries]]: chunks = list() for mboxfile in cmdargs.ambox: if not os.path.exists(mboxfile): logger.critical('Cannot open %s', mboxfile) - return None, None + return None if os.path.isdir(mboxfile): mbx = mailbox.Maildir(mboxfile) @@ -116,7 +119,7 @@ def diff_mboxes(cmdargs): return chunks -def main(cmdargs): +def main(cmdargs: argparse.Namespace) -> None: if cmdargs.ambox is not None: lser, user = diff_mboxes(cmdargs) else: @@ -13,6 +13,7 @@ import b4 import re import json import email +import argparse import urllib.parse import requests @@ -42,7 +43,7 @@ PULL_BODY_REMOTE_REF_RE = [ ] -def git_get_commit_id_from_repo_ref(repo, ref): +def git_get_commit_id_from_repo_ref(repo: str, ref: str) -> Optional[str]: # We only handle git and http/s URLs if not (repo.find('git://') == 0 or repo.find('http://') == 0 or repo.find('https://') == 0): logger.info('%s uses unsupported protocol', repo) @@ -76,7 +77,7 @@ def git_get_commit_id_from_repo_ref(repo, ref): return commit_id -def parse_pr_data(msg): +def parse_pr_data(msg: email.message.Message) -> Optional[b4.LoreMessage]: lmsg = b4.LoreMessage(msg) if lmsg.body is None: logger.critical('Could not find a plain part in the message body') @@ -113,7 +114,7 @@ def parse_pr_data(msg): return lmsg -def attest_fetch_head(gitdir, lmsg): +def attest_fetch_head(gitdir: Optional[str], lmsg: b4.LoreMessage) -> None: config = b4.get_main_config() attpolicy = config['attestation-policy'] if config['attestation-checkmarks'] == 'fancy': @@ -177,7 +178,8 @@ def attest_fetch_head(gitdir, lmsg): sys.exit(128) -def fetch_remote(gitdir, lmsg, branch=None, check_sig=True, ty_track=True): +def fetch_remote(gitdir: Optional[str], lmsg: b4.LoreMessage, branch: Optional[str] = None, + check_sig: bool = True, ty_track: bool = True) -> int: # Do we know anything about this base commit? if lmsg.pr_base_commit and not b4.git_commit_exists(gitdir, lmsg.pr_base_commit): logger.critical('ERROR: git knows nothing about commit %s', lmsg.pr_base_commit) @@ -220,7 +222,7 @@ def fetch_remote(gitdir, lmsg, branch=None, check_sig=True, ty_track=True): return 0 -def thanks_record_pr(lmsg): +def thanks_record_pr(lmsg: b4.LoreMessage) -> None: datadir = b4.get_data_dir() # Check if we're tracking it already filename = '%s.pr' % lmsg.pr_remote_tip_commit @@ -250,7 +252,7 @@ def thanks_record_pr(lmsg): config = b4.get_main_config() pwstate = config.get('pw-review-state') if pwstate: - b4.patchwork_set_state(lmsg.msgid, pwstate) + b4.patchwork_set_state([lmsg.msgid], pwstate) def explode(gitdir: Optional[str], lmsg: b4.LoreMessage, @@ -330,7 +332,7 @@ def explode(gitdir: Optional[str], lmsg: b4.LoreMessage, return msgs -def get_pr_from_github(ghurl: str): +def get_pr_from_github(ghurl: str) -> Optional[b4.LoreMessage]: loc = urllib.parse.urlparse(ghurl) chunks = loc.path.strip('/').split('/') rproj = chunks[0] @@ -389,7 +391,7 @@ def get_pr_from_github(ghurl: str): return lmsg -def main(cmdargs): +def main(cmdargs: argparse.Namespace) -> None: gitdir = cmdargs.gitdir lmsg = None @@ -14,11 +14,14 @@ import email import email.message import email.policy import json +import argparse from string import Template from email import utils from pathlib import Path +from typing import Optional, Tuple, Union, List, Dict + logger = b4.logger DEFAULT_PR_TEMPLATE = """ @@ -53,7 +56,7 @@ MY_COMMITS = None BRANCH_INFO = None -def git_get_merge_id(gitdir, commit_id, branch=None): +def git_get_merge_id(gitdir: Optional[str], commit_id: str, branch: Optional[str] = None) -> Optional[str]: # get merge commit id args = ['rev-list', '%s..' % commit_id, '--ancestry-path'] if branch is not None: @@ -64,17 +67,17 @@ def git_get_merge_id(gitdir, commit_id, branch=None): return lines[-1] -def git_get_rev_diff(gitdir, rev): +def git_get_rev_diff(gitdir: Optional[str], rev: str) -> Tuple[int, Union[str, bytes]]: args = ['diff', '%s~..%s' % (rev, rev)] return b4.git_run_command(gitdir, args) -def git_get_commit_message(gitdir, rev): +def git_get_commit_message(gitdir: Optional[str], rev: str) -> Tuple[int, Union[str, bytes]]: args = ['log', '--format=%B', '-1', rev] return b4.git_run_command(gitdir, args) -def make_reply(reply_template, jsondata, gitdir): +def make_reply(reply_template: str, jsondata: dict, gitdir: Optional[str]) -> email.message.EmailMessage: msg = email.message.EmailMessage() msg['From'] = '%s <%s>' % (jsondata['myname'], jsondata['myemail']) excludes = b4.get_excluded_addrs() @@ -114,7 +117,7 @@ def make_reply(reply_template, jsondata, gitdir): return msg -def auto_locate_pr(gitdir, jsondata, branch): +def auto_locate_pr(gitdir: Optional[str], jsondata: dict, branch: str) -> Optional[str]: pr_commit_id = jsondata['pr_commit_id'] logger.debug('Checking %s', jsondata['pr_commit_id']) if not b4.git_commit_exists(gitdir, pr_commit_id): @@ -150,7 +153,8 @@ def auto_locate_pr(gitdir, jsondata, branch): return merge_commit_id -def get_all_commits(gitdir, branch, since='1.week', committer=None): +def get_all_commits(gitdir: Optional[str], branch: str, since: str = '1.week', + committer: Optional[str] = None) -> Dict[str, Tuple[str, str, List[str]]]: global MY_COMMITS if MY_COMMITS is not None: return MY_COMMITS @@ -187,7 +191,8 @@ def get_all_commits(gitdir, branch, since='1.week', committer=None): return MY_COMMITS -def auto_locate_series(gitdir, jsondata, branch, since='1.week'): +def auto_locate_series(gitdir: Optional[str], jsondata: dict, branch: str, + since: str = '1.week') -> List[Tuple[int, Optional[str]]]: commits = get_all_commits(gitdir, branch, since) patchids = set(commits.keys()) @@ -230,7 +235,7 @@ def auto_locate_series(gitdir, jsondata, branch, since='1.week'): return found -def set_branch_details(gitdir, branch, jsondata, config): +def set_branch_details(gitdir: Optional[str], branch: str, jsondata: dict, config: dict) -> Tuple[dict, dict]: binfo = get_branch_info(gitdir, branch) jsondata['branch'] = branch for key, val in binfo.items(): @@ -262,7 +267,7 @@ def set_branch_details(gitdir, branch, jsondata, config): return jsondata, config -def generate_pr_thanks(gitdir, jsondata, branch): +def generate_pr_thanks(gitdir: Optional[str], jsondata: dict, branch: str) -> email.message.EmailMessage: config = b4.get_main_config() jsondata, config = set_branch_details(gitdir, branch, jsondata, config) thanks_template = DEFAULT_PR_TEMPLATE @@ -291,7 +296,7 @@ def generate_pr_thanks(gitdir, jsondata, branch): return msg -def generate_am_thanks(gitdir, jsondata, branch, since): +def generate_am_thanks(gitdir: Optional[str], jsondata: dict, branch: str, since: str) -> email.message.EmailMessage: config = b4.get_main_config() jsondata, config = set_branch_details(gitdir, branch, jsondata, config) thanks_template = DEFAULT_AM_TEMPLATE @@ -339,7 +344,7 @@ def generate_am_thanks(gitdir, jsondata, branch, since): return msg -def auto_thankanator(cmdargs): +def auto_thankanator(cmdargs: argparse.Namespace) -> None: gitdir = cmdargs.gitdir wantbranch = get_wanted_branch(cmdargs) logger.info('Auto-thankanating commits in %s', wantbranch) @@ -380,7 +385,7 @@ def auto_thankanator(cmdargs): sys.exit(0) -def send_messages(listing, branch, cmdargs): +def send_messages(listing: List[Dict], branch: str, cmdargs: argparse.Namespace) -> None: logger.info('Generating %s thank-you letters', len(listing)) gitdir = cmdargs.gitdir datadir = b4.get_data_dir() @@ -476,7 +481,7 @@ def send_messages(listing, branch, cmdargs): logger.info(' git send-email %s/*.thanks', cmdargs.outdir) -def list_tracked(): +def list_tracked() -> List[Dict]: # find all tracked bits tracked = list() datadir = b4.get_data_dir() @@ -493,7 +498,7 @@ def list_tracked(): return tracked -def write_tracked(tracked): +def write_tracked(tracked: List[Dict]) -> None: counter = 1 config = b4.get_main_config() logger.info('Currently tracking:') @@ -505,7 +510,7 @@ def write_tracked(tracked): counter += 1 -def thank_selected(cmdargs): +def thank_selected(cmdargs: argparse.Namespace) -> None: tracked = list_tracked() if not len(tracked): logger.info('Nothing to do') @@ -538,7 +543,7 @@ def thank_selected(cmdargs): sys.exit(0) -def discard_selected(cmdargs): +def discard_selected(cmdargs: argparse.Namespace) -> None: tracked = list_tracked() if not len(tracked): logger.info('Nothing to do') @@ -588,7 +593,7 @@ def discard_selected(cmdargs): sys.exit(0) -def check_stale_thanks(outdir): +def check_stale_thanks(outdir: str) -> None: if os.path.exists(outdir): for entry in Path(outdir).iterdir(): if entry.suffix == '.thanks': @@ -598,7 +603,7 @@ def check_stale_thanks(outdir): sys.exit(1) -def get_wanted_branch(cmdargs): +def get_wanted_branch(cmdargs: argparse.Namespace) -> str: global BRANCH_INFO gitdir = cmdargs.gitdir if not cmdargs.branch: @@ -622,7 +627,7 @@ def get_wanted_branch(cmdargs): return wantbranch -def get_branch_info(gitdir, branch): +def get_branch_info(gitdir: Optional[str], branch: str) -> Dict: global BRANCH_INFO if BRANCH_INFO is not None: return BRANCH_INFO @@ -663,7 +668,7 @@ def get_branch_info(gitdir, branch): return BRANCH_INFO -def main(cmdargs): +def main(cmdargs: argparse.Namespace) -> None: usercfg = b4.get_user_config() if 'email' not in usercfg: logger.critical('Please set user.email in gitconfig to use this feature.') |