aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMario Limonciello <mario.limonciello@amd.com>2023-10-05 22:19:32 -0500
committerMario Limonciello <mario.limonciello@amd.com>2023-10-13 13:50:00 -0500
commit4d6190714dc635b7c8b0fb384380c6fa66630711 (patch)
tree4df15b391a0fbe522513f8bd40fb23b0cc086a80
parent1be48f85d40871ff96368031123bb8facd499461 (diff)
downloadlinux-firmware-4d6190714dc635b7c8b0fb384380c6fa66630711.tar.gz
Add a script for a robot to open up pull requests
Signed-off-by: Mario Limonciello <mario.limonciello@amd.com>
-rw-r--r--.gitignore2
-rwxr-xr-xcheck_whence.py1
-rwxr-xr-xcontrib/process_linux_firmware.py290
3 files changed, 293 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
index a3876403..ea771274 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
debian/
dist/
release/
+contrib/*.db
+contrib/*.txt
diff --git a/check_whence.py b/check_whence.py
index d824b3ef..4b5471e0 100755
--- a/check_whence.py
+++ b/check_whence.py
@@ -88,6 +88,7 @@ def main():
"contrib/templates/debian.control",
"contrib/templates/debian.copyright",
"contrib/templates/rpm.spec",
+ "contrib/process_linux_firmware.py",
]
)
known_prefixes = set(name for name in whence_list if name.endswith("/"))
diff --git a/contrib/process_linux_firmware.py b/contrib/process_linux_firmware.py
new file mode 100755
index 00000000..6341fca1
--- /dev/null
+++ b/contrib/process_linux_firmware.py
@@ -0,0 +1,290 @@
+#!/usr/bin/python3
+import os
+import time
+import urllib.request
+import sqlite3
+import feedparser
+import argparse
+import logging
+import email
+import subprocess
+import sys
+from datetime import datetime, timedelta, date
+from enum import Enum
+import b4
+
+URL = "https://lore.kernel.org/linux-firmware/new.atom"
+
+
+class ContentType(Enum):
+ REPLY = 1
+ PATCH = 2
+ PULL_REQUEST = 3
+ SPAM = 4
+
+
+content_types = {
+ "diff --git": ContentType.PATCH,
+ "Signed-off-by:": ContentType.PATCH,
+ "are available in the Git repository at": ContentType.PULL_REQUEST,
+}
+
+
+def classify_content(content):
+ # load content into the email library
+ msg = email.message_from_string(content)
+
+ # check the subject
+ subject = msg["Subject"]
+ if "Re:" in subject:
+ return ContentType.REPLY
+ if "PATCH" in subject:
+ return ContentType.PATCH
+
+ for part in msg.walk():
+ if part.get_content_type() == "text/plain":
+ body = part.get_payload(decode=True).decode("utf-8")
+ for key in content_types.keys():
+ if key in body:
+ return content_types[key]
+ break
+ return ContentType.SPAM
+
+
+def fetch_url(url):
+ with urllib.request.urlopen(url) as response:
+ return response.read().decode("utf-8")
+
+
+def quiet_cmd(cmd):
+ logging.debug("Running {}".format(cmd))
+ output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
+ logging.debug(output)
+
+
+def create_pr(remote, branch):
+ cmd = [
+ "git",
+ "push",
+ "-u",
+ remote,
+ branch,
+ "-o",
+ "merge_request.create",
+ "-o",
+ "merge_request.remove_source_branch",
+ "-o",
+ "merge_request.target=main",
+ "-o",
+ "merge_request.title={}".format(branch),
+ ]
+ quiet_cmd(cmd)
+
+
+def refresh_branch():
+ quiet_cmd(["git", "checkout", "main"])
+ quiet_cmd(["git", "pull"])
+
+
+def delete_branch(branch):
+ quiet_cmd(["git", "checkout", "main"])
+ quiet_cmd(["git", "branch", "-D", branch])
+
+
+def process_pr(url, num, remote):
+ branch = "robot/pr-{}-{}".format(num, int(time.time()))
+ cmd = ["b4", "pr", "-b", branch, url]
+ try:
+ quiet_cmd(cmd)
+ except subprocess.CalledProcessError:
+ logging.warning("Failed to apply PR")
+ return
+
+ # determine if it worked (we can't tell unfortunately by return code)
+ cmd = ["git", "branch", "--list", branch]
+ logging.debug("Running {}".format(cmd))
+ result = subprocess.check_output(cmd)
+
+ if result:
+ logging.info("Forwarding PR for {}".format(branch))
+ if remote:
+ create_pr(remote, branch)
+ delete_branch(branch)
+
+
+def process_patch(mbox, num, remote):
+ # create a new branch for the patch
+ branch = "robot/patch-{}-{}".format(num, int(time.time()))
+ cmd = ["git", "checkout", "-b", branch]
+ quiet_cmd(cmd)
+
+ # apply the patch
+ cmd = ["b4", "shazam", "-m", "-"]
+ logging.debug("Running {}".format(cmd))
+ p = subprocess.Popen(
+ cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+ )
+ stdout, stderr = p.communicate(mbox.encode("utf-8"))
+ for line in stdout.splitlines():
+ logging.debug(line.decode("utf-8"))
+ for line in stderr.splitlines():
+ logging.debug(line.decode("utf-8"))
+ if p.returncode != 0:
+ quiet_cmd(["git", "am", "--abort"])
+ else:
+ logging.info("Opening PR for {}".format(branch))
+ if remote:
+ create_pr(remote, branch)
+
+ delete_branch(branch)
+
+
+def update_database(conn, url):
+ c = conn.cursor()
+
+ c.execute(
+ """CREATE TABLE IF NOT EXISTS firmware (url text, processed integer default 0, spam integer default 0)"""
+ )
+
+ # local file
+ if os.path.exists(url):
+ with open(url, "r") as f:
+ atom = f.read()
+ # remote file
+ else:
+ logging.info("Fetching {}".format(url))
+ atom = fetch_url(url)
+
+ # Parse the atom and extract the URLs
+ feed = feedparser.parse(atom)
+
+ # Insert the URLs into the database (oldest first)
+ feed["entries"].reverse()
+ for entry in feed["entries"]:
+ c.execute("SELECT url FROM firmware WHERE url = ?", (entry.link,))
+ if c.fetchone():
+ continue
+ c.execute("INSERT INTO firmware VALUES (?, ?, ?)", (entry.link, 0, 0))
+
+ # Commit the changes and close the connection
+ conn.commit()
+
+
+def process_database(conn, remote):
+ c = conn.cursor()
+
+ # get all unprocessed urls that aren't spam
+ c.execute("SELECT url FROM firmware WHERE processed = 0 AND spam = 0")
+ num = 0
+ msg = ""
+
+ rows = c.fetchall()
+
+ if not rows:
+ logging.info("No new entries")
+ return
+
+ refresh_branch()
+
+ # loop over all unprocessed urls
+ for row in rows:
+
+ msg = "Processing ({}%)".format(round(num / len(rows) * 100))
+ print(msg, end="\r", flush=True)
+
+ url = "{}raw".format(row[0])
+ logging.debug("Processing {}".format(url))
+ mbox = fetch_url(url)
+ classification = classify_content(mbox)
+
+ if classification == ContentType.PATCH:
+ logging.debug("Processing patch ({})".format(row[0]))
+ process_patch(mbox, num, remote)
+
+ if classification == ContentType.PULL_REQUEST:
+ logging.debug("Processing PR ({})".format(row[0]))
+ process_pr(row[0], num, remote)
+
+ if classification == ContentType.SPAM:
+ logging.debug("Marking spam ({})".format(row[0]))
+ c.execute("UPDATE firmware SET spam = 1 WHERE url = ?", (row[0],))
+
+ if classification == ContentType.REPLY:
+ logging.debug("Ignoring reply ({})".format(row[0]))
+
+ c.execute("UPDATE firmware SET processed = 1 WHERE url = ?", (row[0],))
+ num += 1
+ print(" " * len(msg), end="\r", flush=True)
+
+ # commit changes
+ conn.commit()
+ logging.info("Finished processing {} new entries".format(len(rows)))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Process linux-firmware mailing list")
+ parser.add_argument("--url", default=URL, help="URL to get ATOM feed from")
+ parser.add_argument(
+ "--database",
+ default=os.path.join("contrib", "linux_firmware.db"),
+ help="sqlite database to store entries in",
+ )
+ parser.add_argument("--dry", action="store_true", help="Don't open pull requests")
+ parser.add_argument(
+ "--debug", action="store_true", help="Enable debug logging to console"
+ )
+ parser.add_argument("--remote", default="origin", help="Remote to push to")
+ parser.add_argument(
+ "--refresh-cycle", default=0, help="How frequently to run (in minutes)"
+ )
+ args = parser.parse_args()
+
+ if not os.path.exists("WHENCE"):
+ logging.critical(
+ "Please run this script from the root of the linux-firmware repository"
+ )
+ sys.exit(1)
+
+ log = os.path.join(
+ "contrib",
+ "{prefix}-{date}.{suffix}".format(
+ prefix="linux_firmware", suffix="txt", date=date.today()
+ ),
+ )
+ logging.basicConfig(
+ format="%(asctime)s %(levelname)s:\t%(message)s",
+ filename=log,
+ filemode="w",
+ level=logging.DEBUG,
+ )
+
+ # set a format which is simpler for console use
+ console = logging.StreamHandler()
+ if args.debug:
+ console.setLevel(logging.DEBUG)
+ else:
+ console.setLevel(logging.INFO)
+ formatter = logging.Formatter("%(asctime)s : %(levelname)s : %(message)s")
+ console.setFormatter(formatter)
+ logging.getLogger("").addHandler(console)
+
+ while True:
+ conn = sqlite3.connect(args.database)
+ # update the database
+ update_database(conn, args.url)
+
+ if args.dry:
+ remote = ""
+ else:
+ remote = args.remote
+
+ # process the database
+ process_database(conn, remote)
+
+ conn.close()
+
+ if args.refresh_cycle:
+ logging.info("Sleeping for {} minutes".format(args.refresh_cycle))
+ time.sleep(int(args.refresh_cycle) * 60)
+ else:
+ break