From 80190be8d9c82ed816fb571abef416a1fbfb9a35 Mon Sep 17 00:00:00 2001
From: Hongxu Jia <hongxu.jia@windriver.com>
Date: Tue, 31 Jul 2018 17:24:47 +0800
Subject: [PATCH] support authentication for kickstart

While download kickstart file from web server,
we support basic/digest authentication.

Add KickstartAuthError to report authentication failure,
which the invoker could parse this specific error.

Upstream-Status: Inappropriate [oe specific]

Signed-off-by: Hongxu Jia <hongxu.jia@windriver.com>

Rebase to 3.62
Signed-off-by: Mingli Yu <mingli.yu@windriver.com>
---
 pykickstart/errors.py | 17 +++++++++++++++++
 pykickstart/load.py   | 33 ++++++++++++++++++++++++++++-----
 pykickstart/parser.py |  4 ++--
 3 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/pykickstart/errors.py b/pykickstart/errors.py
index 8294f59..3d20bf8 100644
--- a/pykickstart/errors.py
+++ b/pykickstart/errors.py
@@ -32,6 +32,9 @@ This module exports several exception classes:
     KickstartVersionError - An exception for errors relating to unsupported
                             syntax versions.
 
+    KickstartAuthError - An exception for errors relating to authentication
+                         failed while downloading kickstart from web server
+
 And some warning classes:
 
     KickstartWarning - A generic warning class.
@@ -125,3 +128,17 @@ class KickstartDeprecationWarning(KickstartParseWarning, DeprecationWarning):
     """A class for warnings occurring during parsing related to using deprecated
        commands and options.
     """
+
+class KickstartAuthError(KickstartError):
+    """An exception for errors relating to authentication failed while
+       downloading kickstart from web server
+    """
+    def __init__(self, msg):
+        """Create a new KickstartAuthError exception instance with the
+           descriptive message val. val should be the return value of
+           formatErrorMsg.
+        """
+        KickstartError.__init__(self, msg)
+
+    def __str__(self):
+        return self.value
diff --git a/pykickstart/load.py b/pykickstart/load.py
index e8301a4..45d402a 100644
--- a/pykickstart/load.py
+++ b/pykickstart/load.py
@@ -18,9 +18,11 @@
 # with the express permission of Red Hat, Inc.
 #
 import requests
+from requests.auth import HTTPDigestAuth
+from requests.auth import HTTPBasicAuth
 import shutil
 
-from pykickstart.errors import KickstartError
+from pykickstart.errors import KickstartError, KickstartAuthError
 from pykickstart.i18n import _
 from requests.exceptions import SSLError, RequestException
 
@@ -28,7 +30,7 @@ is_url = lambda location: '://' in location  # RFC 3986
 
 SSL_VERIFY = True
 
-def load_to_str(location):
+def load_to_str(location, user=None, passwd=None):
     '''Load a destination URL or file into a string.
     Type of input is inferred automatically.
 
@@ -39,7 +41,7 @@ def load_to_str(location):
     Raises: KickstartError on error reading'''
 
     if is_url(location):
-        return _load_url(location)
+        return _load_url(location, user=user, passwd=passwd)
     else:
         return _load_file(location)
 
@@ -69,11 +71,32 @@ def load_to_file(location, destination):
         _copy_file(location, destination)
         return destination
 
-def _load_url(location):
+def _get_auth(location, user=None, passwd=None):
+
+    auth = None
+    request = requests.get(location, verify=SSL_VERIFY)
+    if request.status_code == requests.codes.unauthorized:
+        if user is None or passwd is None:
+            log.info("Require Authentication")
+            raise KickstartAuthError("Require Authentication.\nAppend 'ksuser=<username> kspasswd=<password>' to boot command")
+
+        reasons = request.headers.get("WWW-Authenticate", "").split()
+        if reasons:
+            auth_type = reasons[0]
+        if auth_type == "Basic":
+            auth = HTTPBasicAuth(user, passwd)
+        elif auth_type == "Digest":
+            auth=HTTPDigestAuth(user, passwd)
+
+    return auth
+
+def _load_url(location, user=None, passwd=None):
     '''Load a location (URL or filename) and return contents as string'''
 
+    auth = _get_auth(location, user=user, passwd=passwd)
+
     try:
-        request = requests.get(location, verify=SSL_VERIFY, timeout=120)
+        request = requests.get(location, verify=SSL_VERIFY, auth=auth, timeout=120)
     except SSLError as e:
         raise KickstartError(_('Error securely accessing URL "%s"') % location + ': {e}'.format(e=str(e)))
     except RequestException as e:
diff --git a/pykickstart/parser.py b/pykickstart/parser.py
index 12b0467..351dc1b 100644
--- a/pykickstart/parser.py
+++ b/pykickstart/parser.py
@@ -831,7 +831,7 @@ class KickstartParser(object):
         i = PutBackIterator(s.splitlines(True) + [""])
         self._stateMachine(i)
 
-    def readKickstart(self, f, reset=True):
+    def readKickstart(self, f, reset=True, username=None, password=None):
         """Process a kickstart file, given by the filename f."""
         if reset:
             self._reset()
@@ -852,7 +852,7 @@ class KickstartParser(object):
         self.currentdir[self._includeDepth] = cd
 
         try:
-            s = load_to_str(f)
+            s = load_to_str(f, user=username, passwd=password)
         except KickstartError as e:
             raise KickstartError(_("Unable to open input kickstart file: %s") % str(e), lineno=0)
 
-- 
2.34.1

