I'm asking this in the context of Google Dataflow, but also generally.
Using PyTorch, I can reference a local directory containing multiple files that comprise a pretrained model. I happen to be working with a Roberta model, but the interface is the same for others.
```
$ ls some-directory/
added_tokens.json
config.json
merges.txt
pytorch_model.bin
special_tokens_map.json
vocab.json
```
```python
from pytorch_transformers import RobertaModel

# this works
model = RobertaModel.from_pretrained('/path/to/some-directory/')
```
However, my pretrained model is stored in a GCS bucket. Let's call it `gs://my-bucket/roberta/`.

In the context of loading this model in Google Dataflow, I'm trying to remain stateless and avoid persisting to disk, so my preference would be to get this model straight from GCS. As I understand it, the general `from_pretrained()` interface can take the string representation of a local directory OR a URL. However, I can't seem to load the model from a GCS URL.
```python
# this fails
model = RobertaModel.from_pretrained('gs://my-bucket/roberta/')
# ValueError: unable to parse gs://my-bucket/roberta/ as a URL or as a local path
```
If I try the public https URL of the directory blob, that fails too, though likely due to lack of authentication: the credentials in my Python environment (which can create GCS clients) don't carry over to plain requests against https://storage.googleapis.com.
```python
# this fails, probably due to auth
bucket = gcs_client.get_bucket('my-bucket')
directory_blob = bucket.blob('roberta')
model = RobertaModel.from_pretrained(directory_blob.public_url)
# ValueError: No JSON object could be decoded

# and for good measure, it also fails if I append a trailing /
model = RobertaModel.from_pretrained(directory_blob.public_url + '/')
# ValueError: No JSON object could be decoded
```
I understand that GCS doesn't actually have subdirectories; it's really just a flat namespace under the bucket name. Still, it seems like I'm blocked both by the need to authenticate and by PyTorch not speaking `gs://`.
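To make the flat-namespace point concrete, here's a small illustration (the object names are made up): a GCS "directory" is nothing more than a shared key prefix.

```python
# GCS object names are flat keys; a "subdirectory" is just a shared prefix.
object_names = [
    "roberta/config.json",
    "roberta/pytorch_model.bin",
    "unrelated.txt",
]

# Listing a "directory" is really a prefix filter, which is what
# bucket.list_blobs(prefix=...) does server-side.
in_roberta = [name for name in object_names if name.startswith("roberta/")]
```

So any directory-style load has to be emulated client-side by listing and fetching every object under the prefix.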
I can get around this by persisting the files locally first.

```python
import os
import tempfile

from google.cloud import storage
from pytorch_transformers import RobertaModel

local_dir = tempfile.mkdtemp()

gcs = storage.Client()
bucket = gcs.get_bucket(bucket_name)
blobs = bucket.list_blobs(prefix=blob_prefix)
for blob in blobs:
    blob.download_to_filename(os.path.join(local_dir, os.path.basename(blob.name)))

model = RobertaModel.from_pretrained(local_dir)
```
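For what it's worth, the workaround can at least clean up after itself. A sketch (untested against real GCS; `model_cls`, `bucket`, and `blob_prefix` are placeholders for e.g. `RobertaModel`, a `google.cloud.storage` bucket object, and `'roberta'`) that downloads into a temporary directory and removes it once the model is in memory:

```python
import os
import tempfile

def load_from_gcs(model_cls, bucket, blob_prefix):
    """Download all blobs under blob_prefix into a throwaway directory,
    load the model from it, and delete the directory afterwards."""
    with tempfile.TemporaryDirectory() as local_dir:
        for blob in bucket.list_blobs(prefix=blob_prefix):
            local_path = os.path.join(local_dir, os.path.basename(blob.name))
            blob.download_to_filename(local_path)
        # from_pretrained reads everything into memory, so the directory
        # can safely be deleted as soon as this returns.
        return model_cls.from_pretrained(local_dir)

# model = load_from_gcs(RobertaModel, bucket, 'roberta')
```

It still touches disk, of course; it just keeps that state ephemeral.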
But this seems like such a hack, and I keep thinking I must be missing something. Surely there's a way to stay stateless and not have to rely on disk persistence!
1. So is there a way to load a pretrained model stored in GCS?
2. Is there a way to authenticate when doing the public URL request in this context?
3. Even if there is a way to authenticate, will the non-existence of subdirectories still be an issue?
Thanks for the help! I'm also happy to be pointed to any duplicate questions 'cause I sure couldn't find any.
Edits and Clarifications
- My Python session is already authenticated to GCS, which is why I'm able to download the blob files locally and then point to that local directory with `from_pretrained()`.
- `from_pretrained()` requires a directory reference because it needs all the files listed at the top of the question, not just `pytorch_model.bin`.
To clarify question #2, I was wondering if there's some way of giving the PyTorch method a request URL that had encrypted credentials embedded or something like that. Kind of a longshot, but I wanted to make sure I hadn't missed anything.
To clarify question #3 (in addition to the comment on one answer below): even if there's a way to embed credentials in the URL that I don't know about, I'd still need to reference a directory rather than a single blob, and I don't know whether a GCS "subdirectory" would be recognized as such, because (as the Google docs state) subdirectories in GCS are an illusion and don't represent a real directory structure. So I think this question is irrelevant, or at least blocked by question #2, but it's a thread I chased for a bit, so I'm still curious.