Parcourir la source

Implement s3:// protocol (#11511)

For those who want to pull models from S3

Signed-off-by: Eric Curtin <ecurtin@redhat.com>
Eric Curtin il y a 11 mois
Parent
commit
ecef206ccb
1 fichier modifié avec 43 ajouts et 0 suppression
  1. 43 0
      examples/run/run.cpp

+ 43 - 0
examples/run/run.cpp

@@ -65,6 +65,13 @@ static int printe(const char * fmt, ...) {
     return ret;
     return ret;
 }
 }
 
 
+// Renders the broken-down time `tm` through std::put_time using the format
+// string `fmt` and returns the formatted text.
+static std::string strftime_fmt(const char * fmt, const std::tm & tm) {
+    std::ostringstream formatted;
+    const std::tm *    when = &tm;
+
+    formatted << std::put_time(when, fmt);
+    return formatted.str();
+}
+
 class Opt {
 class Opt {
   public:
   public:
     int init(int argc, const char ** argv) {
     int init(int argc, const char ** argv) {
@@ -698,6 +705,39 @@ class LlamaData {
         return download(url, bn, true);
         return download(url, bn, true);
     }
     }
 
 
+    // Downloads an object from an S3 bucket via download().
+    //
+    // `model` is the text that followed "s3://", i.e. "<bucket>/<key>"; `bn` is
+    // forwarded unchanged to download(). Credentials are read from the
+    // AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables.
+    // Returns 1 when the input is malformed, the credentials are missing or the
+    // current time cannot be determined; otherwise returns download()'s result.
+    //
+    // NOTE(review): the Authorization header built below carries only the
+    // credential scope. It has no SignedHeaders or Signature component and the
+    // secret key is never used to sign anything, so this is not a complete AWS
+    // Signature Version 4 request (that needs HMAC-SHA256, which is not
+    // implemented here). The region is also fixed to us-east-1.
+    int s3_dl(const std::string & model, const std::string & bn) {
+        const size_t slash_pos = model.find('/');
+        // Reject a missing separator, an empty bucket ("/key") and an empty key ("bucket/").
+        if (slash_pos == std::string::npos || slash_pos == 0 || slash_pos + 1 >= model.size()) {
+            printe("Invalid S3 location, expected s3://<bucket>/<key>\n");
+            return 1;
+        }
+
+        const std::string bucket     = model.substr(0, slash_pos);
+        const std::string key        = model.substr(slash_pos + 1);
+        const char *      access_key = std::getenv("AWS_ACCESS_KEY_ID");
+        const char *      secret_key = std::getenv("AWS_SECRET_ACCESS_KEY");
+        // A variable that is set but empty is as unusable as one that is unset.
+        if (!access_key || !secret_key || access_key[0] == '\0' || secret_key[0] == '\0') {
+            printe("AWS credentials not found in environment\n");
+            return 1;
+        }
+
+        // Current UTC time for the credential scope and the x-amz-date header.
+        // Both time() and gmtime() can fail, so check before dereferencing.
+        const time_t now = time(nullptr);
+        if (now == static_cast<time_t>(-1)) {
+            printe("Failed to read the current time\n");
+            return 1;
+        }
+
+        const std::tm * utc_ptr = gmtime(&now);
+        if (!utc_ptr) {
+            printe("Failed to convert the current time to UTC\n");
+            return 1;
+        }
+
+        // Copy immediately: gmtime() returns a pointer to shared static storage.
+        const std::tm                  utc      = *utc_ptr;
+        const std::string              date     = strftime_fmt("%Y%m%d", utc);
+        const std::string              datetime = strftime_fmt("%Y%m%dT%H%M%SZ", utc);
+        const std::vector<std::string> headers  = {
+            "Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date +
+                "/us-east-1/s3/aws4_request",
+            "x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime
+        };
+
+        const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key;
+
+        return download(url, bn, true, headers);
+    }
+
     std::string basename(const std::string & path) {
     std::string basename(const std::string & path) {
         const size_t pos = path.find_last_of("/\\");
         const size_t pos = path.find_last_of("/\\");
         if (pos == std::string::npos) {
         if (pos == std::string::npos) {
@@ -738,6 +778,9 @@ class LlamaData {
             rm_until_substring(model_, "github:");
             rm_until_substring(model_, "github:");
             rm_until_substring(model_, "://");
             rm_until_substring(model_, "://");
             ret = github_dl(model_, bn);
             ret = github_dl(model_, bn);
+        } else if (string_starts_with(model_, "s3://")) {
+            rm_until_substring(model_, "://");
+            ret = s3_dl(model_, bn);
         } else {  // ollama:// or nothing
         } else {  // ollama:// or nothing
             rm_until_substring(model_, "ollama.com/library/");
             rm_until_substring(model_, "ollama.com/library/");
             rm_until_substring(model_, "://");
             rm_until_substring(model_, "://");