1 Commits

Author SHA1 Message Date
Oscar Krause
c6ca1cafb9 serve drivers directly via api if configured 2025-04-10 09:02:41 +02:00
8 changed files with 225 additions and 444 deletions

View File

@@ -16,12 +16,12 @@ build:docker:
interruptible: true
stage: build
rules:
# deployment is in "deploy:docker:"
- if: $CI_PIPELINE_SOURCE == 'merge_request_event'
- if: $CI_COMMIT_BRANCH && $CI_COMMIT_BRANCH != $CI_DEFAULT_BRANCH
changes:
- app/**/*
- Dockerfile
- requirements.txt
- if: $CI_PIPELINE_SOURCE == 'merge_request_event'
tags: [ docker ]
before_script:
- docker buildx inspect
@@ -44,13 +44,16 @@ build:apt:
- if: $CI_COMMIT_TAG
variables:
VERSION: $CI_COMMIT_REF_NAME
- if: ($CI_PIPELINE_SOURCE == 'merge_request_event') || ($CI_COMMIT_BRANCH && $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH)
- if: $CI_COMMIT_BRANCH && $CI_COMMIT_BRANCH != $CI_DEFAULT_BRANCH
changes:
- app/**/*
- .DEBIAN/**/*
- .gitlab-ci.yml
variables:
VERSION: "0.0.1"
- if: $CI_PIPELINE_SOURCE == 'merge_request_event'
variables:
VERSION: "0.0.1"
before_script:
- echo -e "VERSION=$VERSION\nCOMMIT=$CI_COMMIT_SHA" > version.env
# install build dependencies
@@ -91,13 +94,16 @@ build:pacman:
- if: $CI_COMMIT_TAG
variables:
VERSION: $CI_COMMIT_REF_NAME
- if: ($CI_PIPELINE_SOURCE == 'merge_request_event') || ($CI_COMMIT_BRANCH && $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH)
- if: $CI_COMMIT_BRANCH && $CI_COMMIT_BRANCH != $CI_DEFAULT_BRANCH
changes:
- app/**/*
- .PKGBUILD/**/*
- .gitlab-ci.yml
variables:
VERSION: "0.0.1"
- if: $CI_PIPELINE_SOURCE == 'merge_request_event'
variables:
VERSION: "0.0.1"
before_script:
#- echo -e "VERSION=$VERSION\nCOMMIT=$CI_COMMIT_SHA" > version.env
# install build dependencies
@@ -120,12 +126,13 @@ build:pacman:
paths:
- "*.pkg.tar.zst"
test:python:
image: $IMAGE
test:
image: python:3.12-slim-bookworm
stage: test
interruptible: true
rules:
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
- if: $CI_COMMIT_TAG
- if: $CI_PIPELINE_SOURCE == "merge_request_event"
- if: $CI_COMMIT_BRANCH && $CI_COMMIT_BRANCH != $CI_DEFAULT_BRANCH
changes:
@@ -135,20 +142,17 @@ test:python:
DATABASE: sqlite:///../app/db.sqlite
parallel:
matrix:
- IMAGE:
# https://devguide.python.org/versions/#supported-versions
# - python:3.14-rc-alpine # EOL 2030-10 => uvicorn does not support 3.14 yet
- python:3.13-alpine # EOL 2029-10
- python:3.12-alpine # EOL 2028-10
- python:3.11-alpine # EOL 2027-10
# - python:3.10-alpine # EOL 2026-10 => ImportError: cannot import name 'UTC' from 'datetime'
# - python:3.9-alpine # EOL 2025-10 => ImportError: cannot import name 'UTC' from 'datetime'
- REQUIREMENTS:
- 'requirements.txt'
# - '.DEBIAN/requirements-bookworm-12.txt'
# - '.DEBIAN/requirements-ubuntu-24.04.txt'
# - '.DEBIAN/requirements-ubuntu-24.10.txt'
before_script:
- apk --no-cache add openssl
- apt-get update && apt-get install -y python3-dev python3-pip python3-venv gcc
- python3 -m venv venv
- source venv/bin/activate
- pip install --upgrade pip
- pip install -r requirements.txt
- pip install -r $REQUIREMENTS
- pip install pytest pytest-cov pytest-custom_exit_code httpx
- mkdir -p app/cert
- openssl genrsa -out app/cert/instance.private.pem 2048
@@ -158,26 +162,17 @@ test:python:
- python -m pytest main.py --junitxml=report.xml
artifacts:
reports:
dotenv: version.env
junit: ['**/report.xml']
test:apt:
image: $IMAGE
.test:apt:
stage: test
rules:
- if: $CI_COMMIT_BRANCH && $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
- if: $CI_PIPELINE_SOURCE == 'merge_request_event'
- if: $CI_COMMIT_BRANCH && $CI_COMMIT_BRANCH != $CI_DEFAULT_BRANCH
changes:
- app/**/*
- .DEBIAN/**/*
- if: $CI_PIPELINE_SOURCE == 'merge_request_event'
parallel:
matrix:
- IMAGE:
- debian:trixie-slim # EOL: t.b.a.
- debian:bookworm-slim # EOL: June 06, 2026
- debian:bookworm-slim # EOL: June 06, 2026
- ubuntu:24.04 # EOL: April 2036
- ubuntu:24.10
needs:
- job: build:apt
artifacts: true
@@ -209,15 +204,24 @@ test:apt:
- apt-get purge -qq -y fastapi-dls
- apt-get autoremove -qq -y && apt-get clean -qq
test:apt:
extends: .test:apt
image: $IMAGE
parallel:
matrix:
- IMAGE:
- debian:bookworm-slim # EOL: June 06, 2026
- ubuntu:24.04 # EOL: April 2036
- ubuntu:24.10
test:pacman:archlinux:
image: archlinux:base
rules:
- if: $CI_COMMIT_BRANCH && $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
- if: $CI_PIPELINE_SOURCE == 'merge_request_event'
- if: $CI_COMMIT_BRANCH && $CI_COMMIT_BRANCH != $CI_DEFAULT_BRANCH
changes:
- app/**/*
- .PKGBUILD/**/*
- .gitlab-ci.yml
- if: $CI_PIPELINE_SOURCE == 'merge_request_event'
needs:
- job: build:pacman
artifacts: true
@@ -292,12 +296,15 @@ gemnasium-python-dependency_scanning:
- if: $CI_PIPELINE_SOURCE == "merge_request_event"
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
.deploy:
rules:
- if: $CI_COMMIT_TAG
deploy:docker:
extends: .deploy
image: docker:dind
stage: deploy
tags: [ docker ]
rules:
- if: $CI_COMMIT_TAG
before_script:
- echo "Building docker image for commit $CI_COMMIT_SHA with version $CI_COMMIT_REF_NAME"
- docker buildx inspect
@@ -316,10 +323,9 @@ deploy:docker:
deploy:apt:
# doc: https://git.collinwebdesigns.de/help/user/packages/debian_repository/index.md#install-a-package
extends: .deploy
image: debian:bookworm-slim
stage: deploy
rules:
- if: $CI_COMMIT_TAG
needs:
- job: build:apt
artifacts: true
@@ -356,10 +362,9 @@ deploy:apt:
- 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file ${EXPORT_NAME} "${CI_API_V4_URL}/projects/${CI_PROJECT_ID}/packages/generic/${PACKAGE_NAME}/${PACKAGE_VERSION}/${EXPORT_NAME}"'
deploy:pacman:
extends: .deploy
image: archlinux:base-devel
stage: deploy
rules:
- if: $CI_COMMIT_TAG
needs:
- job: build:pacman
artifacts: true
@@ -380,7 +385,7 @@ deploy:pacman:
release:
image: registry.gitlab.com/gitlab-org/release-cli:latest
stage: .post
needs: [ build:docker, build:apt, build:pacman ]
needs: [ test ]
rules:
- if: $CI_COMMIT_TAG
script:

View File

@@ -795,13 +795,13 @@ Thanks to vGPU community and all who uses this project and report bugs.
Special thanks to:
- `samicrusader` who created build file for **ArchLinux**
- `cyrus` who wrote the section for **openSUSE**
- `midi` who wrote the section for **unRAID**
- `polloloco` who wrote the *[NVIDIA vGPU Guide](https://gitlab.com/polloloco/vgpu-proxmox)*
- `DualCoder` who creates the `vgpu_unlock` functionality [vgpu_unlock](https://github.com/DualCoder/vgpu_unlock)
- `Krutav Shah` who wrote the [vGPU_Unlock Wiki](https://docs.google.com/document/d/1pzrWJ9h-zANCtyqRgS7Vzla0Y8Ea2-5z2HEi4X75d2Q/)
- `Wim van 't Hoog` for the [Proxmox All-In-One Installer Script](https://wvthoog.nl/proxmox-vgpu-v3/)
- `mrzenc` who wrote [fastapi-dls-nixos](https://github.com/mrzenc/fastapi-dls-nixos)
- @samicrusader who created build file for **ArchLinux**
- @cyrus who wrote the section for **openSUSE**
- @midi who wrote the section for **unRAID**
- @polloloco who wrote the *[NVIDIA vGPU Guide](https://gitlab.com/polloloco/vgpu-proxmox)*
- @DualCoder who creates the `vgpu_unlock` functionality [vgpu_unlock](https://github.com/DualCoder/vgpu_unlock)
- Krutav Shah who wrote the [vGPU_Unlock Wiki](https://docs.google.com/document/d/1pzrWJ9h-zANCtyqRgS7Vzla0Y8Ea2-5z2HEi4X75d2Q/)
- Wim van 't Hoog for the [Proxmox All-In-One Installer Script](https://wvthoog.nl/proxmox-vgpu-v3/)
- @mrzenc who wrote [fastapi-dls-nixos](https://github.com/mrzenc/fastapi-dls-nixos)
And thanks to all people who contributed to all these libraries!

View File

@@ -1,12 +1,12 @@
import logging
import sys
import os.path
from base64 import b64encode as b64enc
from calendar import timegm
from contextlib import asynccontextmanager
from datetime import datetime, UTC
from datetime import datetime, timedelta, UTC
from hashlib import sha256
from json import loads as json_loads
from os import getenv as env
from os import getenv as env, listdir
from os.path import join, dirname
from uuid import uuid4
@@ -14,14 +14,16 @@ from dateutil.relativedelta import relativedelta
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.requests import Request
from jose import jws, jwt, JWTError
from fastapi.staticfiles import StaticFiles
from jose import jws, jwk, jwt, JWTError
from jose.constants import ALGORITHMS
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse
from orm import Origin, Lease, init as db_init, migrate, Instance, Site
from orm import Origin, Lease, init as db_init, migrate
from util import PrivateKey, PublicKey, load_file
# Load variables
load_dotenv('../version.env')
@@ -39,9 +41,21 @@ db_init(db), migrate(db)
# Load DLS variables (all prefixed with "INSTANCE_*" is used as "SERVICE_INSTANCE_*" or "SI_*" in official dls service)
DLS_URL = str(env('DLS_URL', 'localhost'))
DLS_PORT = int(env('DLS_PORT', '443'))
SITE_KEY_XID = str(env('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000'))
INSTANCE_REF = str(env('INSTANCE_REF', '10000000-0000-0000-0000-000000000001'))
ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001'))
INSTANCE_KEY_RSA = PrivateKey.from_file(str(env('INSTANCE_KEY_RSA', join(dirname(__file__), 'cert/instance.private.pem'))))
INSTANCE_KEY_PUB = PublicKey.from_file(str(env('INSTANCE_KEY_PUB', join(dirname(__file__), 'cert/instance.public.pem'))))
TOKEN_EXPIRE_DELTA = relativedelta(days=int(env('TOKEN_EXPIRE_DAYS', 1)), hours=int(env('TOKEN_EXPIRE_HOURS', 0)))
LEASE_EXPIRE_DELTA = relativedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0)))
LEASE_RENEWAL_PERIOD = float(env('LEASE_RENEWAL_PERIOD', 0.15))
LEASE_RENEWAL_DELTA = timedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0)))
CLIENT_TOKEN_EXPIRE_DELTA = relativedelta(years=12)
CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}']
DRIVERS_DIR = env('DRIVERS_DIR', None)
ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001')) # todo
jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.pem(), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.pem(), algorithm=ALGORITHMS.RS256)
# Logging
LOG_LEVEL = logging.DEBUG if DEBUG else logging.INFO
@@ -49,33 +63,25 @@ logging.basicConfig(format='[{levelname:^7}] [{module:^15}] {message}', style='{
logger = logging.getLogger(__name__)
logger.setLevel(LOG_LEVEL)
logging.getLogger('util').setLevel(LOG_LEVEL)
logging.getLogger('DriverMatrix').setLevel(LOG_LEVEL)
logging.getLogger('NV').setLevel(LOG_LEVEL)
# FastAPI
@asynccontextmanager
async def lifespan(_: FastAPI):
# on startup
default_instance = Instance.get_default_instance(db)
lease_renewal_period = default_instance.lease_renewal_period
lease_renewal_delta = default_instance.get_lease_renewal_delta()
client_token_expire_delta = default_instance.get_client_token_expire_delta()
logger.info(f'''
Using timezone: {str(TZ)}. Make sure this is correct and match your clients!
Your clients will renew their license every {str(Lease.calculate_renewal(lease_renewal_period, lease_renewal_delta))}.
If the renewal fails, the license is valid for {str(lease_renewal_delta)}.
Your clients renew their license every {str(Lease.calculate_renewal(LEASE_RENEWAL_PERIOD, LEASE_RENEWAL_DELTA))}.
If the renewal fails, the license is {str(LEASE_RENEWAL_DELTA)} valid.
Your client-token file (.tok) is valid for {str(client_token_expire_delta)}.
Your client-token file (.tok) is valid for {str(CLIENT_TOKEN_EXPIRE_DELTA)}.
''')
logger.info(f'Debug is {"enabled" if DEBUG else "disabled"}.')
validate_settings()
yield
# on shutdown
@@ -85,6 +91,9 @@ async def lifespan(_: FastAPI):
config = dict(openapi_url=None, docs_url=None, redoc_url=None) # dict(openapi_url='/-/openapi.json', docs_url='/-/docs', redoc_url='/-/redoc')
app = FastAPI(title='FastAPI-DLS', description='Minimal Delegated License Service (DLS).', version=VERSION, lifespan=lifespan, **config)
if DRIVERS_DIR is not None:
app.mount('/-/static-drivers', StaticFiles(directory=str(DRIVERS_DIR), html=False), name='drivers')
app.debug = DEBUG
app.add_middleware(
CORSMiddleware,
@@ -96,24 +105,12 @@ app.add_middleware(
# Helper
def __get_token(request: Request, jwt_decode_key: "jose.jwt") -> dict:
def __get_token(request: Request) -> dict:
authorization_header = request.headers.get('authorization')
token = authorization_header.split(' ')[1]
return jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False})
def validate_settings():
session = sessionmaker(bind=db)()
lease_expire_delta_min, lease_expire_delta_max = 86_400, 7_776_000
for instance in session.query(Instance).all():
lease_expire_delta = instance.lease_expire_delta
if lease_expire_delta < 86_400 or lease_expire_delta > 7_776_000:
logging.warning(f'> [ instance ]: {instance.instance_ref}: "lease_expire_delta" should be between {lease_expire_delta_min} and {lease_expire_delta_max}')
session.close()
# Endpoints
@app.get('/', summary='Index')
@@ -133,20 +130,18 @@ async def _health():
@app.get('/-/config', summary='* Config', description='returns environment variables.')
async def _config():
default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)
return JSONr({
'VERSION': str(VERSION),
'COMMIT': str(COMMIT),
'DEBUG': str(DEBUG),
'DLS_URL': str(DLS_URL),
'DLS_PORT': str(DLS_PORT),
'SITE_KEY_XID': str(default_site.site_key),
'INSTANCE_REF': str(default_instance.instance_ref),
'SITE_KEY_XID': str(SITE_KEY_XID),
'INSTANCE_REF': str(INSTANCE_REF),
'ALLOTMENT_REF': [str(ALLOTMENT_REF)],
'TOKEN_EXPIRE_DELTA': str(default_instance.get_token_expire_delta()),
'LEASE_EXPIRE_DELTA': str(default_instance.get_lease_expire_delta()),
'LEASE_RENEWAL_PERIOD': str(default_instance.lease_renewal_period),
'TOKEN_EXPIRE_DELTA': str(TOKEN_EXPIRE_DELTA),
'LEASE_EXPIRE_DELTA': str(LEASE_EXPIRE_DELTA),
'LEASE_RENEWAL_PERIOD': str(LEASE_RENEWAL_PERIOD),
'CORS_ORIGINS': str(CORS_ORIGINS),
'TZ': str(TZ),
})
@@ -155,7 +150,6 @@ async def _config():
@app.get('/-/readme', summary='* Readme')
async def _readme():
from markdown import markdown
from util import load_file
content = load_file(join(dirname(__file__), '../README.md')).decode('utf-8')
return HTMLr(markdown(text=content, extensions=['tables', 'fenced_code', 'md_in_html', 'nl2br', 'toc']))
@@ -198,6 +192,25 @@ async def _manage(request: Request):
return HTMLr(response)
@app.get('/-/drivers/{directory:path}', summary='* List drivers directory')
async def _drivers(request: Request, directory: str | None):
if DRIVERS_DIR is None:
return Response(status_code=404, content=f'Variable "DRIVERS_DIR" not set.')
path = os.path.join(DRIVERS_DIR, directory)
if not os.path.exists(path) and not os.path.isfile(path):
return Response(status_code=404, content=f'Resource "{path}" not found!')
content = [{
"type": "file" if os.path.isfile(f'{path}/{_}') else "folder" if os.path.isdir(f'{path}/{_}') else "unknown",
"name": _,
"link": f'/-/static-drivers/{directory}{_}',
} for _ in listdir(path)]
return JSONr({"directory": path, "content": content})
@app.get('/-/origins', summary='* Origins')
async def _origins(request: Request, leases: bool = False):
session = sessionmaker(bind=db)()
@@ -205,7 +218,8 @@ async def _origins(request: Request, leases: bool = False):
for origin in session.query(Origin).all():
x = origin.serialize()
if leases:
x['leases'] = list(map(lambda _: _.serialize(), Lease.find_by_origin_ref(db, origin.origin_ref)))
serialize = dict(renewal_period=LEASE_RENEWAL_PERIOD, renewal_delta=LEASE_RENEWAL_DELTA)
x['leases'] = list(map(lambda _: _.serialize(**serialize), Lease.find_by_origin_ref(db, origin.origin_ref)))
response.append(x)
session.close()
return JSONr(response)
@@ -222,7 +236,8 @@ async def _leases(request: Request, origin: bool = False):
session = sessionmaker(bind=db)()
response = []
for lease in session.query(Lease).all():
x = lease.serialize()
serialize = dict(renewal_period=LEASE_RENEWAL_PERIOD, renewal_delta=LEASE_RENEWAL_DELTA)
x = lease.serialize(**serialize)
if origin:
lease_origin = session.query(Origin).filter(Origin.origin_ref == lease.origin_ref).first()
if lease_origin is not None:
@@ -249,13 +264,7 @@ async def _lease_delete(request: Request, lease_ref: str):
@app.get('/-/client-token', summary='* Client-Token', description='creates a new messenger token for this service instance')
async def _client_token():
cur_time = datetime.now(UTC)
default_instance = Instance.get_default_instance(db)
public_key = default_instance.get_public_key()
# todo: implemented request parameter to support different instances
jwt_encode_key = default_instance.get_jwt_encode_key()
exp_time = cur_time + default_instance.get_client_token_expire_delta()
exp_time = cur_time + CLIENT_TOKEN_EXPIRE_DELTA
payload = {
"jti": str(uuid4()),
@@ -268,7 +277,7 @@ async def _client_token():
"scope_ref_list": [ALLOTMENT_REF],
"fulfillment_class_ref_list": [],
"service_instance_configuration": {
"nls_service_instance_ref": default_instance.instance_ref,
"nls_service_instance_ref": INSTANCE_REF,
"svc_port_set_list": [
{
"idx": 0,
@@ -280,10 +289,10 @@ async def _client_token():
},
"service_instance_public_key_configuration": {
"service_instance_public_key_me": {
"mod": hex(public_key.raw().public_numbers().n)[2:],
"exp": int(public_key.raw().public_numbers().e),
"mod": hex(INSTANCE_KEY_PUB.raw().public_numbers().n)[2:],
"exp": int(INSTANCE_KEY_PUB.raw().public_numbers().e),
},
"service_instance_public_key_pem": public_key.pem().decode('utf-8'),
"service_instance_public_key_pem": INSTANCE_KEY_PUB.pem().decode('utf-8'),
"key_retention_mode": "LATEST_ONLY"
},
}
@@ -365,16 +374,13 @@ async def auth_v1_code(request: Request):
delta = relativedelta(minutes=15)
expires = cur_time + delta
default_site = Site.get_default_site(db)
jwt_encode_key = Instance.get_default_instance(db).get_jwt_encode_key()
payload = {
'iat': timegm(cur_time.timetuple()),
'exp': timegm(expires.timetuple()),
'challenge': j.get('code_challenge'),
'origin_ref': j.get('origin_ref'),
'key_ref': default_site.site_key,
'kid': default_site.site_key,
'key_ref': SITE_KEY_XID,
'kid': SITE_KEY_XID
}
auth_code = jws.sign(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256)
@@ -394,9 +400,6 @@ async def auth_v1_code(request: Request):
async def auth_v1_token(request: Request):
j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.now(UTC)
default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)
jwt_encode_key, jwt_decode_key = default_instance.get_jwt_encode_key(), default_instance.get_jwt_decode_key()
try:
payload = jwt.decode(token=j.get('auth_code'), key=jwt_decode_key, algorithms=ALGORITHMS.RS256)
except JWTError as e:
@@ -410,7 +413,7 @@ async def auth_v1_token(request: Request):
if payload.get('challenge') != challenge:
return JSONr(status_code=401, content={'status': 401, 'detail': 'expected challenge did not match verifier'})
access_expires_on = cur_time + default_instance.get_token_expire_delta()
access_expires_on = cur_time + TOKEN_EXPIRE_DELTA
new_payload = {
'iat': timegm(cur_time.timetuple()),
@@ -419,8 +422,8 @@ async def auth_v1_token(request: Request):
'aud': 'https://cls.nvidia.org',
'exp': timegm(access_expires_on.timetuple()),
'origin_ref': origin_ref,
'key_ref': default_site.site_key,
'kid': default_site.site_key,
'key_ref': SITE_KEY_XID,
'kid': SITE_KEY_XID,
}
auth_token = jwt.encode(new_payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256)
@@ -437,13 +440,10 @@ async def auth_v1_token(request: Request):
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
@app.post('/leasing/v1/lessor', description='request multiple leases (borrow) for current origin')
async def leasing_v1_lessor(request: Request):
j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.now(UTC)
default_instance = Instance.get_default_instance(db)
jwt_decode_key = default_instance.get_jwt_decode_key()
j, token, cur_time = json_loads((await request.body()).decode('utf-8')), __get_token(request), datetime.now(UTC)
try:
token = __get_token(request, jwt_decode_key)
token = __get_token(request)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
@@ -457,7 +457,7 @@ async def leasing_v1_lessor(request: Request):
# return JSONr(status_code=500, detail=f'no service instances found for scopes: ["{scope_ref}"]')
lease_ref = str(uuid4())
expires = cur_time + default_instance.get_lease_expire_delta()
expires = cur_time + LEASE_EXPIRE_DELTA
lease_result_list.append({
"ordinal": 0,
# https://docs.nvidia.com/license-system/latest/nvidia-license-system-user-guide/index.html
@@ -465,13 +465,13 @@ async def leasing_v1_lessor(request: Request):
"ref": lease_ref,
"created": cur_time.isoformat(),
"expires": expires.isoformat(),
"recommended_lease_renewal": default_instance.lease_renewal_period,
"recommended_lease_renewal": LEASE_RENEWAL_PERIOD,
"offline_lease": "true",
"license_type": "CONCURRENT_COUNTED_SINGLE"
}
})
data = Lease(instance_ref=default_instance.instance_ref, origin_ref=origin_ref, lease_ref=lease_ref, lease_created=cur_time, lease_expires=expires)
data = Lease(origin_ref=origin_ref, lease_ref=lease_ref, lease_created=cur_time, lease_expires=expires)
Lease.create_or_update(db, data)
response = {
@@ -488,14 +488,7 @@ async def leasing_v1_lessor(request: Request):
# venv/lib/python3.9/site-packages/nls_dal_service_instance_dls/schema/service_instance/V1_0_21__product_mapping.sql
@app.get('/leasing/v1/lessor/leases', description='get active leases for current origin')
async def leasing_v1_lessor_lease(request: Request):
cur_time = datetime.now(UTC)
jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
try:
token = __get_token(request, jwt_decode_key)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
token, cur_time = __get_token(request), datetime.now(UTC)
origin_ref = token.get('origin_ref')
@@ -515,15 +508,7 @@ async def leasing_v1_lessor_lease(request: Request):
# venv/lib/python3.9/site-packages/nls_core_lease/lease_single.py
@app.put('/leasing/v1/lease/{lease_ref}', description='renew a lease')
async def leasing_v1_lease_renew(request: Request, lease_ref: str):
cur_time = datetime.now(UTC)
default_instance = Instance.get_default_instance(db)
jwt_decode_key = default_instance.get_jwt_decode_key()
try:
token = __get_token(request, jwt_decode_key)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
token, cur_time = __get_token(request), datetime.now(UTC)
origin_ref = token.get('origin_ref')
logger.info(f'> [ renew ]: {origin_ref}: renew {lease_ref}')
@@ -532,11 +517,11 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str):
if entity is None:
return JSONr(status_code=404, content={'status': 404, 'detail': 'requested lease not available'})
expires = cur_time + default_instance.get_lease_expire_delta()
expires = cur_time + LEASE_EXPIRE_DELTA
response = {
"lease_ref": lease_ref,
"expires": expires.isoformat(),
"recommended_lease_renewal": default_instance.lease_renewal_period,
"recommended_lease_renewal": LEASE_RENEWAL_PERIOD,
"offline_lease": True,
"prompts": None,
"sync_timestamp": cur_time.isoformat(),
@@ -550,14 +535,7 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str):
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_single_controller.py
@app.delete('/leasing/v1/lease/{lease_ref}', description='release (return) a lease')
async def leasing_v1_lease_delete(request: Request, lease_ref: str):
cur_time = datetime.now(UTC)
jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
try:
token = __get_token(request, jwt_decode_key)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
token, cur_time = __get_token(request), datetime.now(UTC)
origin_ref = token.get('origin_ref')
logger.info(f'> [ return ]: {origin_ref}: return {lease_ref}')
@@ -583,14 +561,7 @@ async def leasing_v1_lease_delete(request: Request, lease_ref: str):
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
@app.delete('/leasing/v1/lessor/leases', description='release all leases')
async def leasing_v1_lessor_lease_remove(request: Request):
cur_time = datetime.now(UTC)
jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
try:
token = __get_token(request, jwt_decode_key)
except JWTError:
return JSONr(status_code=401, content={'status': 401, 'detail': 'token is not valid'})
token, cur_time = __get_token(request), datetime.now(UTC)
origin_ref = token.get('origin_ref')
@@ -612,8 +583,6 @@ async def leasing_v1_lessor_lease_remove(request: Request):
async def leasing_v1_lessor_shutdown(request: Request):
j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.now(UTC)
jwt_decode_key = Instance.get_default_instance(db).get_jwt_decode_key()
token = j.get('token')
token = jwt.decode(token=token, key=jwt_decode_key, algorithms=ALGORITHMS.RS256, options={'verify_aud': False})
origin_ref = token.get('origin_ref')

View File

@@ -1,143 +1,20 @@
import logging
from datetime import datetime, timedelta, timezone, UTC
from os import getenv as env
from os.path import join, dirname, isfile
from dateutil.relativedelta import relativedelta
from jose import jwk
from jose.constants import ALGORITHMS
from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, update, and_, inspect, text, BLOB, INT, FLOAT
from sqlalchemy import Column, VARCHAR, CHAR, ForeignKey, DATETIME, update, and_, inspect, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker, declarative_base, Session, relationship
from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, declarative_base
from util import DriverMatrix, PrivateKey, PublicKey, DriverMatrix
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
from util import NV
Base = declarative_base()
class Site(Base):
__tablename__ = "site"
INITIAL_SITE_KEY_XID = '10000000-0000-0000-0000-000000000000'
INITIAL_SITE_NAME = 'default-site'
site_key = Column(CHAR(length=36), primary_key=True, unique=True, index=True) # uuid4, SITE_KEY_XID
name = Column(VARCHAR(length=256), nullable=False)
def __str__(self):
return f'SITE_KEY_XID: {self.site_key}'
@staticmethod
def create_statement(engine: Engine):
return CreateTable(Site.__table__).compile(engine)
@staticmethod
def get_default_site(engine: Engine) -> "Site":
session = sessionmaker(bind=engine)()
entity = session.query(Site).filter(Site.site_key == Site.INITIAL_SITE_KEY_XID).first()
session.close()
return entity
class Instance(Base):
__tablename__ = "instance"
DEFAULT_INSTANCE_REF = '10000000-0000-0000-0000-000000000001'
DEFAULT_TOKEN_EXPIRE_DELTA = 86_400 # 1 day
DEFAULT_LEASE_EXPIRE_DELTA = 7_776_000 # 90 days
DEFAULT_LEASE_RENEWAL_PERIOD = 0.15
DEFAULT_CLIENT_TOKEN_EXPIRE_DELTA = 378_432_000 # 12 years
# 1 day = 86400 (min. in production setup, max 90 days), 1 hour = 3600
instance_ref = Column(CHAR(length=36), primary_key=True, unique=True, index=True) # uuid4, INSTANCE_REF
site_key = Column(CHAR(length=36), ForeignKey(Site.site_key, ondelete='CASCADE'), nullable=False, index=True) # uuid4
private_key = Column(BLOB(length=2048), nullable=False)
public_key = Column(BLOB(length=512), nullable=False)
token_expire_delta = Column(INT(), nullable=False, default=DEFAULT_TOKEN_EXPIRE_DELTA, comment='in seconds')
lease_expire_delta = Column(INT(), nullable=False, default=DEFAULT_LEASE_EXPIRE_DELTA, comment='in seconds')
lease_renewal_period = Column(FLOAT(precision=2), nullable=False, default=DEFAULT_LEASE_RENEWAL_PERIOD)
client_token_expire_delta = Column(INT(), nullable=False, default=DEFAULT_CLIENT_TOKEN_EXPIRE_DELTA, comment='in seconds')
__origin = relationship(Site, foreign_keys=[site_key])
def __str__(self):
return f'INSTANCE_REF: {self.instance_ref} (SITE_KEY_XID: {self.site_key})'
@staticmethod
def create_statement(engine: Engine):
return CreateTable(Instance.__table__).compile(engine)
@staticmethod
def create_or_update(engine: Engine, instance: "Instance"):
session = sessionmaker(bind=engine)()
entity = session.query(Instance).filter(Instance.instance_ref == instance.instance_ref).first()
if entity is None:
session.add(instance)
else:
x = dict(
site_key=instance.site_key,
private_key=instance.private_key,
public_key=instance.public_key,
token_expire_delta=instance.token_expire_delta,
lease_expire_delta=instance.lease_expire_delta,
lease_renewal_period=instance.lease_renewal_period,
client_token_expire_delta=instance.client_token_expire_delta,
)
session.execute(update(Instance).where(Instance.instance_ref == instance.instance_ref).values(**x))
session.commit()
session.flush()
session.close()
# todo: validate on startup that "lease_expire_delta" is between 1 day and 90 days
@staticmethod
def get_default_instance(engine: Engine) -> "Instance":
session = sessionmaker(bind=engine)()
site = Site.get_default_site(engine)
entity = session.query(Instance).filter(Instance.site_key == site.site_key).first()
session.close()
return entity
def get_token_expire_delta(self) -> "dateutil.relativedelta.relativedelta":
return relativedelta(seconds=self.token_expire_delta)
def get_lease_expire_delta(self) -> "dateutil.relativedelta.relativedelta":
return relativedelta(seconds=self.lease_expire_delta)
def get_lease_renewal_delta(self) -> "datetime.timedelta":
return timedelta(seconds=self.lease_expire_delta)
def get_client_token_expire_delta(self) -> "dateutil.relativedelta.relativedelta":
return relativedelta(seconds=self.client_token_expire_delta)
def __get_private_key(self) -> "PrivateKey":
return PrivateKey(self.private_key)
def get_public_key(self) -> "PublicKey":
return PublicKey(self.public_key)
def get_jwt_encode_key(self) -> "jose.jkw":
return jwk.construct(self.__get_private_key().pem().decode('utf-8'), algorithm=ALGORITHMS.RS256)
def get_jwt_decode_key(self) -> "jose.jwt":
return jwk.construct(self.get_public_key().pem().decode('utf-8'), algorithm=ALGORITHMS.RS256)
def get_private_key_str(self, encoding: str = 'utf-8') -> str:
return self.private_key.decode(encoding)
def get_public_key_str(self, encoding: str = 'utf-8') -> str:
return self.private_key.decode(encoding)
class Origin(Base):
__tablename__ = "origin"
origin_ref = Column(CHAR(length=36), primary_key=True, unique=True, index=True) # uuid4
# service_instance_xid = Column(CHAR(length=36), nullable=False, index=True) # uuid4 # not necessary, we only support one service_instance_xid ('INSTANCE_REF')
hostname = Column(VARCHAR(length=256), nullable=True)
guest_driver_version = Column(VARCHAR(length=10), nullable=True)
@@ -148,7 +25,7 @@ class Origin(Base):
return f'Origin(origin_ref={self.origin_ref}, hostname={self.hostname})'
def serialize(self) -> dict:
_ = DriverMatrix().find(self.guest_driver_version)
_ = NV().find(self.guest_driver_version)
return {
'origin_ref': self.origin_ref,
@@ -162,6 +39,7 @@ class Origin(Base):
@staticmethod
def create_statement(engine: Engine):
from sqlalchemy.schema import CreateTable
return CreateTable(Origin.__table__).compile(engine)
@staticmethod
@@ -207,24 +85,18 @@ class Origin(Base):
class Lease(Base):
__tablename__ = "lease"
instance_ref = Column(CHAR(length=36), ForeignKey(Instance.instance_ref, ondelete='CASCADE'), nullable=False, index=True) # uuid4
lease_ref = Column(CHAR(length=36), primary_key=True, nullable=False, index=True) # uuid4
origin_ref = Column(CHAR(length=36), ForeignKey(Origin.origin_ref, ondelete='CASCADE'), nullable=False, index=True) # uuid4
# scope_ref = Column(CHAR(length=36), nullable=False, index=True) # uuid4 # not necessary, we only support one scope_ref ('ALLOTMENT_REF')
lease_created = Column(DATETIME(), nullable=False)
lease_expires = Column(DATETIME(), nullable=False)
lease_updated = Column(DATETIME(), nullable=False)
__instance = relationship(Instance, foreign_keys=[instance_ref])
__origin = relationship(Origin, foreign_keys=[origin_ref])
def __repr__(self):
return f'Lease(origin_ref={self.origin_ref}, lease_ref={self.lease_ref}, expires={self.lease_expires})'
def serialize(self) -> dict:
renewal_period = self.__instance.lease_renewal_period
renewal_delta = self.__instance.get_lease_renewal_delta
def serialize(self, renewal_period: float, renewal_delta: timedelta) -> dict:
lease_renewal = int(Lease.calculate_renewal(renewal_period, renewal_delta).total_seconds())
lease_renewal = self.lease_updated + relativedelta(seconds=lease_renewal)
@@ -240,6 +112,7 @@ class Lease(Base):
@staticmethod
def create_statement(engine: Engine):
from sqlalchemy.schema import CreateTable
return CreateTable(Lease.__table__).compile(engine)
@staticmethod
@@ -333,104 +206,38 @@ class Lease(Base):
return renew
def init_default_site(session: Session):
private_key = PrivateKey.generate()
public_key = private_key.public_key()
site = Site(
site_key=Site.INITIAL_SITE_KEY_XID,
name=Site.INITIAL_SITE_NAME
)
session.add(site)
session.commit()
instance = Instance(
instance_ref=Instance.DEFAULT_INSTANCE_REF,
site_key=site.site_key,
private_key=private_key.pem(),
public_key=public_key.pem(),
)
session.add(instance)
session.commit()
def init(engine: Engine):
tables = [Site, Instance, Origin, Lease]
tables = [Origin, Lease]
db = inspect(engine)
session = sessionmaker(bind=engine)()
for table in tables:
exists = db.dialect.has_table(engine.connect(), table.__tablename__)
logger.info(f'> Table "{table.__tablename__:<16}" exists: {exists}')
if not exists:
if not db.dialect.has_table(engine.connect(), table.__tablename__):
session.execute(text(str(table.create_statement(engine))))
session.commit()
# create default site
cnt = session.query(Site).count()
if cnt == 0:
init_default_site(session)
session.flush()
session.close()
def migrate(engine: Engine):
db = inspect(engine)
# todo: add update guide to use 1.LATEST to 2.0
def upgrade_1_x_to_2_0():
site = Site.get_default_site(engine)
logger.info(site)
instance = Instance.get_default_instance(engine)
logger.info(instance)
def upgrade_1_0_to_1_1():
x = db.dialect.get_columns(engine.connect(), Lease.__tablename__)
x = next(_ for _ in x if _['name'] == 'origin_ref')
if x['primary_key'] > 0:
print('Found old database schema with "origin_ref" as primary-key in "lease" table. Dropping table!')
print(' Your leases are recreated on next renewal!')
print(' If an error message appears on the client, you can ignore it.')
Lease.__table__.drop(bind=engine)
init(engine)
# SITE_KEY_XID
if site_key := env('SITE_KEY_XID', None) is not None:
site.site_key = str(site_key)
# def upgrade_1_2_to_1_3():
# x = db.dialect.get_columns(engine.connect(), Lease.__tablename__)
# x = next((_ for _ in x if _['name'] == 'scope_ref'), None)
# if x is None:
# Lease.scope_ref.compile()
# column_name = Lease.scope_ref.name
# column_type = Lease.scope_ref.type.compile(engine.dialect)
# engine.execute(f'ALTER TABLE "{Lease.__tablename__}" ADD COLUMN "{column_name}" {column_type}')
# INSTANCE_REF
if instance_ref := env('INSTANCE_REF', None) is not None:
instance.instance_ref = str(instance_ref)
# ALLOTMENT_REF
if allotment_ref := env('ALLOTMENT_REF', None) is not None:
pass # todo
# INSTANCE_KEY_RSA, INSTANCE_KEY_PUB
default_instance_private_key_path = str(join(dirname(__file__), 'cert/instance.private.pem'))
instance_private_key = env('INSTANCE_KEY_RSA', None)
if instance_private_key is not None:
instance.private_key = PrivateKey(instance_private_key.encode('utf-8'))
elif isfile(default_instance_private_key_path):
instance.private_key = PrivateKey.from_file(default_instance_private_key_path)
default_instance_public_key_path = str(join(dirname(__file__), 'cert/instance.public.pem'))
instance_public_key = env('INSTANCE_KEY_PUB', None)
if instance_public_key is not None:
instance.public_key = PublicKey(instance_public_key.encode('utf-8'))
elif isfile(default_instance_public_key_path):
instance.public_key = PublicKey.from_file(default_instance_public_key_path)
# TOKEN_EXPIRE_DELTA
token_expire_delta = env('TOKEN_EXPIRE_DAYS', None)
if token_expire_delta not in (None, 0):
instance.token_expire_delta = token_expire_delta * 86_400
token_expire_delta = env('TOKEN_EXPIRE_HOURS', None)
if token_expire_delta not in (None, 0):
instance.token_expire_delta = token_expire_delta * 3_600
# LEASE_EXPIRE_DELTA, LEASE_RENEWAL_DELTA
lease_expire_delta = env('LEASE_EXPIRE_DAYS', None)
if lease_expire_delta not in (None, 0):
instance.lease_expire_delta = lease_expire_delta * 86_400
lease_expire_delta = env('LEASE_EXPIRE_HOURS', None)
if lease_expire_delta not in (None, 0):
instance.lease_expire_delta = lease_expire_delta * 3_600
# LEASE_RENEWAL_PERIOD
lease_renewal_period = env('LEASE_RENEWAL_PERIOD', None)
if lease_renewal_period is not None:
instance.lease_renewal_period = lease_renewal_period
# todo: update site, instance
upgrade_1_x_to_2_0()
upgrade_1_0_to_1_1()
# upgrade_1_2_to_1_3()

View File

@@ -1,5 +1,4 @@
import logging
from json import load as json_load
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey, generate_private_key
@@ -8,14 +7,6 @@ from cryptography.hazmat.primitives.serialization import load_pem_private_key, l
logging.basicConfig()
def load_file(filename: str) -> bytes:
log = logging.getLogger(f'{__name__}')
log.debug(f'Loading contents of file "{filename}')
with open(filename, 'rb') as file:
content = file.read()
return content
class PrivateKey:
def __init__(self, data: bytes):
@@ -85,32 +76,37 @@ class PublicKey:
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
def load_file(filename: str) -> bytes:
log = logging.getLogger(f'{__name__}')
log.debug(f'Loading contents of file "{filename}')
with open(filename, 'rb') as file:
content = file.read()
return content
class DriverMatrix:
class NV:
__DRIVER_MATRIX_FILENAME = 'static/driver_matrix.json'
__DRIVER_MATRIX: None | dict = None # https://docs.nvidia.com/grid/ => "Driver Versions"
def __init__(self):
self.log = logging.getLogger(self.__class__.__name__)
if DriverMatrix.__DRIVER_MATRIX is None:
self.__load()
def __load(self):
if NV.__DRIVER_MATRIX is None:
from json import load as json_load
try:
file = open(DriverMatrix.__DRIVER_MATRIX_FILENAME)
DriverMatrix.__DRIVER_MATRIX = json_load(file)
file = open(NV.__DRIVER_MATRIX_FILENAME)
NV.__DRIVER_MATRIX = json_load(file)
file.close()
self.log.debug(f'Successfully loaded "{DriverMatrix.__DRIVER_MATRIX_FILENAME}".')
self.log.debug(f'Successfully loaded "{NV.__DRIVER_MATRIX_FILENAME}".')
except Exception as e:
DriverMatrix.__DRIVER_MATRIX = {} # init empty dict to not try open file everytime, just when restarting app
# self.log.warning(f'Failed to load "{DriverMatrix.__DRIVER_MATRIX_FILENAME}": {e}')
NV.__DRIVER_MATRIX = {} # init empty dict to not try open file everytime, just when restarting app
# self.log.warning(f'Failed to load "{NV.__DRIVER_MATRIX_FILENAME}": {e}')
@staticmethod
def find(version: str) -> dict | None:
if DriverMatrix.__DRIVER_MATRIX is None:
if NV.__DRIVER_MATRIX is None:
return None
for idx, (key, branch) in enumerate(DriverMatrix.__DRIVER_MATRIX.items()):
for idx, (key, branch) in enumerate(NV.__DRIVER_MATRIX.items()):
for release in branch.get('$releases'):
linux_driver = release.get('Linux Driver')
windows_driver = release.get('Windows Driver')

View File

@@ -1,8 +1,8 @@
fastapi==0.115.12
uvicorn[standard]==0.34.1
uvicorn[standard]==0.34.0
python-jose[cryptography]==3.4.0
cryptography==44.0.2
python-dateutil==2.9.0
sqlalchemy==2.0.40
markdown==3.8
markdown==3.7
python-dotenv==1.1.0

View File

@@ -6,7 +6,7 @@ logger.setLevel(logging.INFO)
URL = 'https://docs.nvidia.com/vgpu/index.html'
BRANCH_STATUS_KEY = 'vGPU Branch Status'
BRANCH_STATUS_KEY, SOFTWARE_BRANCH_KEY, = 'vGPU Branch Status', 'vGPU Software Branch'
VGPU_KEY, GRID_KEY, DRIVER_BRANCH_KEY = 'vGPU Software', 'vGPU Software', 'Driver Branch'
LINUX_VGPU_MANAGER_KEY, LINUX_DRIVER_KEY = 'Linux vGPU Manager', 'Linux Driver'
WINDOWS_VGPU_MANAGER_KEY, WINDOWS_DRIVER_KEY = 'Windows vGPU Manager', 'Windows Driver'
@@ -26,15 +26,12 @@ def __driver_versions(html: 'BeautifulSoup'):
# find wrapper for "DriverVersions" and find tables
data = html.find('div', {'id': 'driver-versions'})
items = data.find_all('bsp-accordion', {'class': 'Accordion-items-item'})
items = data.findAll('bsp-accordion', {'class': 'Accordion-items-item'})
for item in items:
software_branch = item.find('div', {'class': 'Accordion-items-item-title'}).text.strip()
software_branch = software_branch.replace(' Releases', '')
matrix_key = software_branch.lower()
branch_status = item.find('a', href=True, string='Branch status')
branch_status = branch_status.next_sibling.replace(':', '').strip()
# driver version info from table-heads (ths) and table-rows (trs)
table = item.find('table')
ths, trs = table.find_all('th'), table.find_all('tr')
@@ -45,20 +42,48 @@ def __driver_versions(html: 'BeautifulSoup'):
continue
# create dict with table-heads as key and cell content as value
x = {headers[i]: __strip(cell.text) for i, cell in enumerate(tds)}
x.setdefault(BRANCH_STATUS_KEY, branch_status)
releases.append(x)
# add to matrix
MATRIX.update({matrix_key: {JSON_RELEASES_KEY: releases}})
def __release_branches(html: 'BeautifulSoup'):
# find wrapper for "AllReleaseBranches" and find table
data = html.find('div', {'id': 'all-release-branches'})
table = data.find('table')
# branch releases info from table-heads (ths) and table-rows (trs)
ths, trs = table.find_all('th'), table.find_all('tr')
headers = [header.text.strip() for header in ths]
for trs in trs:
tds = trs.find_all('td')
if len(tds) == 0: # skip empty
continue
# create dict with table-heads as key and cell content as value
x = {headers[i]: cell.text.strip() for i, cell in enumerate(tds)}
# get matrix_key
software_branch = x.get(SOFTWARE_BRANCH_KEY)
matrix_key = software_branch.lower()
# add to matrix
MATRIX.update({matrix_key: MATRIX.get(matrix_key) | x})
def __debug():
# print table head
s = f'{VGPU_KEY:^13} | {LINUX_VGPU_MANAGER_KEY:^21} | {LINUX_DRIVER_KEY:^21} | {WINDOWS_VGPU_MANAGER_KEY:^21} | {WINDOWS_DRIVER_KEY:^21} | {RELEASE_DATE_KEY:>21} | {BRANCH_STATUS_KEY:^21}'
s = f'{SOFTWARE_BRANCH_KEY:^21} | {BRANCH_STATUS_KEY:^21} | {VGPU_KEY:^13} | {LINUX_VGPU_MANAGER_KEY:^21} | {LINUX_DRIVER_KEY:^21} | {WINDOWS_VGPU_MANAGER_KEY:^21} | {WINDOWS_DRIVER_KEY:^21} | {RELEASE_DATE_KEY:>21} | {EOL_KEY:>21}'
print(s)
# iterate over dict & format some variables to not overload table
for idx, (key, branch) in enumerate(MATRIX.items()):
branch_status = branch.get(BRANCH_STATUS_KEY)
branch_status = branch_status.replace('Branch ', '')
branch_status = branch_status.replace('Long-Term Support', 'LTS')
branch_status = branch_status.replace('Production', 'Prod.')
software_branch = branch.get(SOFTWARE_BRANCH_KEY).replace('NVIDIA ', '')
for release in branch.get(JSON_RELEASES_KEY):
version = release.get(VGPU_KEY, release.get(GRID_KEY, ''))
linux_manager = release.get(LINUX_VGPU_MANAGER_KEY, release.get(ALT_VGPU_MANAGER_KEY, ''))
@@ -67,25 +92,13 @@ def __debug():
windows_driver = release.get(WINDOWS_DRIVER_KEY)
release_date = release.get(RELEASE_DATE_KEY)
is_latest = release.get(VGPU_KEY) == branch.get(LATEST_KEY)
branch_status = __parse_branch_status(release.get(BRANCH_STATUS_KEY, ''))
version = f'{version} *' if is_latest else version
s = f'{version:<13} | {linux_manager:<21} | {linux_driver:<21} | {windows_manager:<21} | {windows_driver:<21} | {release_date:>21} | {branch_status:^21}'
eol = branch.get(EOL_KEY) if is_latest else ''
s = f'{software_branch:^21} | {branch_status:^21} | {version:<13} | {linux_manager:<21} | {linux_driver:<21} | {windows_manager:<21} | {windows_driver:<21} | {release_date:>21} | {eol:>21}'
print(s)
def __parse_branch_status(string: str) -> str:
string = string.replace('Production Branch', 'Prod. -')
string = string.replace('Long-Term Support Branch', 'LTS -')
string = string.replace('supported until', '')
string = string.replace('EOL since', 'EOL - ')
string = string.replace('EOL from', 'EOL -')
return string
def __dump(filename: str):
import json
@@ -115,6 +128,7 @@ if __name__ == '__main__':
# build matrix
__driver_versions(soup)
__release_branches(soup)
# debug output
__debug()

View File

@@ -3,13 +3,12 @@ from base64 import b64encode as b64enc
from calendar import timegm
from datetime import datetime, UTC
from hashlib import sha256
from os import getenv as env
from os.path import dirname, join
from uuid import uuid4, UUID
from dateutil.relativedelta import relativedelta
from jose import jwt
from jose import jwt, jwk
from jose.constants import ALGORITHMS
from sqlalchemy import create_engine
from starlette.testclient import TestClient
# add relative path to use packages as they were in the app/ dir
@@ -17,23 +16,20 @@ sys.path.append('../')
sys.path.append('../app')
from app import main
from orm import init as db_init, migrate, Site, Instance
from util import PrivateKey, PublicKey
client = TestClient(main.app)
ORIGIN_REF, ALLOTMENT_REF, SECRET = str(uuid4()), '20000000-0000-0000-0000-000000000001', 'HelloWorld'
# fastapi setup
client = TestClient(main.app)
# INSTANCE_KEY_RSA = generate_key()
# INSTANCE_KEY_PUB = INSTANCE_KEY_RSA.public_key()
# database setup
db = create_engine(str(env('DATABASE', 'sqlite:///db.sqlite')))
db_init(db), migrate(db)
INSTANCE_KEY_RSA = PrivateKey.from_file(str(join(dirname(__file__), '../app/cert/instance.private.pem')))
INSTANCE_KEY_PUB = PublicKey.from_file(str(join(dirname(__file__), '../app/cert/instance.public.pem')))
# test vars
DEFAULT_SITE, DEFAULT_INSTANCE = Site.get_default_site(db), Instance.get_default_instance(db)
SITE_KEY = DEFAULT_SITE.site_key
jwt_encode_key, jwt_decode_key = DEFAULT_INSTANCE.get_jwt_encode_key(), DEFAULT_INSTANCE.get_jwt_decode_key()
jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.pem(), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.pem(), algorithm=ALGORITHMS.RS256)
def __bearer_token(origin_ref: str) -> str:
@@ -42,12 +38,6 @@ def __bearer_token(origin_ref: str) -> str:
return token
def test_initial_default_site_and_instance():
default_site, default_instance = Site.get_default_site(db), Instance.get_default_instance(db)
assert default_site.site_key == Site.INITIAL_SITE_KEY_XID
assert default_instance.instance_ref == Instance.DEFAULT_INSTANCE_REF
def test_index():
response = client.get('/')
assert response.status_code == 200