♻️ Refactor backend, settings, DB sessions, types, configs, plugins (#158)
* ♻️ Refactor backend, update DB session handling * ✨ Add mypy config and plugins * ➕ Use Python-jose instead of PyJWT as it has some extra functionalities and features * ✨ Add/update scripts for test, lint, format * 🔧 Update lint and format configs * 🎨 Update import format, comments, and types * 🎨 Add types to config * ✨ Add types for all the code, and small fixes * 🎨 Use global imports to simplify exploring with Jupyter * ♻️ Import schemas and models, instead of each class * 🚚 Rename db_session to db for simplicity * 📌 Update dependencies installation for testing
This commit is contained in:
committed by
GitHub
parent
4b80bdfdce
commit
eed33d276d
3
{{cookiecutter.project_slug}}/backend/app/.flake8
Normal file
3
{{cookiecutter.project_slug}}/backend/app/.flake8
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[flake8]
|
||||||
|
max-line-length = 88
|
||||||
|
exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache
|
||||||
3
{{cookiecutter.project_slug}}/backend/app/.gitignore
vendored
Normal file
3
{{cookiecutter.project_slug}}/backend/app/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
.mypy_cache
|
||||||
|
.coverage
|
||||||
|
htmlcov
|
||||||
@@ -1,24 +1,21 @@
|
|||||||
from typing import List
|
from typing import Any, List
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app import crud
|
from app import crud, models, schemas
|
||||||
from app.api.utils.db import get_db
|
from app.api import deps
|
||||||
from app.api.utils.security import get_current_active_user
|
|
||||||
from app.models.user import User as DBUser
|
|
||||||
from app.schemas.item import Item, ItemCreate, ItemUpdate
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=List[Item])
|
@router.get("/", response_model=List[schemas.Item])
|
||||||
def read_items(
|
def read_items(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
current_user: DBUser = Depends(get_current_active_user),
|
current_user: models.User = Depends(deps.get_current_active_user),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Retrieve items.
|
Retrieve items.
|
||||||
"""
|
"""
|
||||||
@@ -26,58 +23,56 @@ def read_items(
|
|||||||
items = crud.item.get_multi(db, skip=skip, limit=limit)
|
items = crud.item.get_multi(db, skip=skip, limit=limit)
|
||||||
else:
|
else:
|
||||||
items = crud.item.get_multi_by_owner(
|
items = crud.item.get_multi_by_owner(
|
||||||
db_session=db, owner_id=current_user.id, skip=skip, limit=limit
|
db=db, owner_id=current_user.id, skip=skip, limit=limit
|
||||||
)
|
)
|
||||||
return items
|
return items
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=Item)
|
@router.post("/", response_model=schemas.Item)
|
||||||
def create_item(
|
def create_item(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
item_in: ItemCreate,
|
item_in: schemas.ItemCreate,
|
||||||
current_user: DBUser = Depends(get_current_active_user),
|
current_user: models.User = Depends(deps.get_current_active_user),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Create new item.
|
Create new item.
|
||||||
"""
|
"""
|
||||||
item = crud.item.create_with_owner(
|
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=current_user.id)
|
||||||
db_session=db, obj_in=item_in, owner_id=current_user.id
|
|
||||||
)
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{id}", response_model=Item)
|
@router.put("/{id}", response_model=schemas.Item)
|
||||||
def update_item(
|
def update_item(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
id: int,
|
id: int,
|
||||||
item_in: ItemUpdate,
|
item_in: schemas.ItemUpdate,
|
||||||
current_user: DBUser = Depends(get_current_active_user),
|
current_user: models.User = Depends(deps.get_current_active_user),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Update an item.
|
Update an item.
|
||||||
"""
|
"""
|
||||||
item = crud.item.get(db_session=db, id=id)
|
item = crud.item.get(db=db, id=id)
|
||||||
if not item:
|
if not item:
|
||||||
raise HTTPException(status_code=404, detail="Item not found")
|
raise HTTPException(status_code=404, detail="Item not found")
|
||||||
if not crud.user.is_superuser(current_user) and (item.owner_id != current_user.id):
|
if not crud.user.is_superuser(current_user) and (item.owner_id != current_user.id):
|
||||||
raise HTTPException(status_code=400, detail="Not enough permissions")
|
raise HTTPException(status_code=400, detail="Not enough permissions")
|
||||||
item = crud.item.update(db_session=db, db_obj=item, obj_in=item_in)
|
item = crud.item.update(db=db, db_obj=item, obj_in=item_in)
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{id}", response_model=Item)
|
@router.get("/{id}", response_model=schemas.Item)
|
||||||
def read_item(
|
def read_item(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
id: int,
|
id: int,
|
||||||
current_user: DBUser = Depends(get_current_active_user),
|
current_user: models.User = Depends(deps.get_current_active_user),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get item by ID.
|
Get item by ID.
|
||||||
"""
|
"""
|
||||||
item = crud.item.get(db_session=db, id=id)
|
item = crud.item.get(db=db, id=id)
|
||||||
if not item:
|
if not item:
|
||||||
raise HTTPException(status_code=404, detail="Item not found")
|
raise HTTPException(status_code=404, detail="Item not found")
|
||||||
if not crud.user.is_superuser(current_user) and (item.owner_id != current_user.id):
|
if not crud.user.is_superuser(current_user) and (item.owner_id != current_user.id):
|
||||||
@@ -85,20 +80,20 @@ def read_item(
|
|||||||
return item
|
return item
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{id}", response_model=Item)
|
@router.delete("/{id}", response_model=schemas.Item)
|
||||||
def delete_item(
|
def delete_item(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
id: int,
|
id: int,
|
||||||
current_user: DBUser = Depends(get_current_active_user),
|
current_user: models.User = Depends(deps.get_current_active_user),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Delete an item.
|
Delete an item.
|
||||||
"""
|
"""
|
||||||
item = crud.item.get(db_session=db, id=id)
|
item = crud.item.get(db=db, id=id)
|
||||||
if not item:
|
if not item:
|
||||||
raise HTTPException(status_code=404, detail="Item not found")
|
raise HTTPException(status_code=404, detail="Item not found")
|
||||||
if not crud.user.is_superuser(current_user) and (item.owner_id != current_user.id):
|
if not crud.user.is_superuser(current_user) and (item.owner_id != current_user.id):
|
||||||
raise HTTPException(status_code=400, detail="Not enough permissions")
|
raise HTTPException(status_code=400, detail="Not enough permissions")
|
||||||
item = crud.item.remove(db_session=db, id=id)
|
item = crud.item.remove(db=db, id=id)
|
||||||
return item
|
return item
|
||||||
|
|||||||
@@ -1,19 +1,15 @@
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app import crud
|
from app import crud, models, schemas
|
||||||
from app.api.utils.db import get_db
|
from app.api import deps
|
||||||
from app.api.utils.security import get_current_user
|
from app.core import security
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.jwt import create_access_token
|
|
||||||
from app.core.security import get_password_hash
|
from app.core.security import get_password_hash
|
||||||
from app.models.user import User as DBUser
|
|
||||||
from app.schemas.msg import Msg
|
|
||||||
from app.schemas.token import Token
|
|
||||||
from app.schemas.user import User
|
|
||||||
from app.utils import (
|
from app.utils import (
|
||||||
generate_password_reset_token,
|
generate_password_reset_token,
|
||||||
send_reset_password_email,
|
send_reset_password_email,
|
||||||
@@ -23,10 +19,10 @@ from app.utils import (
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login/access-token", response_model=Token)
|
@router.post("/login/access-token", response_model=schemas.Token)
|
||||||
def login_access_token(
|
def login_access_token(
|
||||||
db: Session = Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends()
|
db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends()
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
OAuth2 compatible token login, get an access token for future requests
|
OAuth2 compatible token login, get an access token for future requests
|
||||||
"""
|
"""
|
||||||
@@ -39,23 +35,23 @@ def login_access_token(
|
|||||||
raise HTTPException(status_code=400, detail="Inactive user")
|
raise HTTPException(status_code=400, detail="Inactive user")
|
||||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
return {
|
return {
|
||||||
"access_token": create_access_token(
|
"access_token": security.create_access_token(
|
||||||
data={"user_id": user.id}, expires_delta=access_token_expires
|
user.id, expires_delta=access_token_expires
|
||||||
),
|
),
|
||||||
"token_type": "bearer",
|
"token_type": "bearer",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login/test-token", response_model=User)
|
@router.post("/login/test-token", response_model=schemas.User)
|
||||||
def test_token(current_user: DBUser = Depends(get_current_user)):
|
def test_token(current_user: models.User = Depends(deps.get_current_user)) -> Any:
|
||||||
"""
|
"""
|
||||||
Test access token
|
Test access token
|
||||||
"""
|
"""
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
@router.post("/password-recovery/{email}", response_model=Msg)
|
@router.post("/password-recovery/{email}", response_model=schemas.Msg)
|
||||||
def recover_password(email: str, db: Session = Depends(get_db)):
|
def recover_password(email: str, db: Session = Depends(deps.get_db)) -> Any:
|
||||||
"""
|
"""
|
||||||
Password Recovery
|
Password Recovery
|
||||||
"""
|
"""
|
||||||
@@ -73,10 +69,12 @@ def recover_password(email: str, db: Session = Depends(get_db)):
|
|||||||
return {"msg": "Password recovery email sent"}
|
return {"msg": "Password recovery email sent"}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/reset-password/", response_model=Msg)
|
@router.post("/reset-password/", response_model=schemas.Msg)
|
||||||
def reset_password(
|
def reset_password(
|
||||||
token: str = Body(...), new_password: str = Body(...), db: Session = Depends(get_db)
|
token: str = Body(...),
|
||||||
):
|
new_password: str = Body(...),
|
||||||
|
db: Session = Depends(deps.get_db),
|
||||||
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Reset password
|
Reset password
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,28 +1,25 @@
|
|||||||
from typing import List
|
from typing import Any, List
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from pydantic.networks import EmailStr
|
from pydantic.networks import EmailStr
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app import crud
|
from app import crud, models, schemas
|
||||||
from app.api.utils.db import get_db
|
from app.api import deps
|
||||||
from app.api.utils.security import get_current_active_superuser, get_current_active_user
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.models.user import User as DBUser
|
|
||||||
from app.schemas.user import User, UserCreate, UserUpdate
|
|
||||||
from app.utils import send_new_account_email
|
from app.utils import send_new_account_email
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=List[User])
|
@router.get("/", response_model=List[schemas.User])
|
||||||
def read_users(
|
def read_users(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
current_user: DBUser = Depends(get_current_active_superuser),
|
current_user: models.User = Depends(deps.get_current_active_superuser),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Retrieve users.
|
Retrieve users.
|
||||||
"""
|
"""
|
||||||
@@ -30,13 +27,13 @@ def read_users(
|
|||||||
return users
|
return users
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=User)
|
@router.post("/", response_model=schemas.User)
|
||||||
def create_user(
|
def create_user(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
user_in: UserCreate,
|
user_in: schemas.UserCreate,
|
||||||
current_user: DBUser = Depends(get_current_active_superuser),
|
current_user: models.User = Depends(deps.get_current_active_superuser),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Create new user.
|
Create new user.
|
||||||
"""
|
"""
|
||||||
@@ -54,20 +51,20 @@ def create_user(
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@router.put("/me", response_model=User)
|
@router.put("/me", response_model=schemas.User)
|
||||||
def update_user_me(
|
def update_user_me(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
password: str = Body(None),
|
password: str = Body(None),
|
||||||
full_name: str = Body(None),
|
full_name: str = Body(None),
|
||||||
email: EmailStr = Body(None),
|
email: EmailStr = Body(None),
|
||||||
current_user: DBUser = Depends(get_current_active_user),
|
current_user: models.User = Depends(deps.get_current_active_user),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Update own user.
|
Update own user.
|
||||||
"""
|
"""
|
||||||
current_user_data = jsonable_encoder(current_user)
|
current_user_data = jsonable_encoder(current_user)
|
||||||
user_in = UserUpdate(**current_user_data)
|
user_in = schemas.UserUpdate(**current_user_data)
|
||||||
if password is not None:
|
if password is not None:
|
||||||
user_in.password = password
|
user_in.password = password
|
||||||
if full_name is not None:
|
if full_name is not None:
|
||||||
@@ -78,25 +75,25 @@ def update_user_me(
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=User)
|
@router.get("/me", response_model=schemas.User)
|
||||||
def read_user_me(
|
def read_user_me(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
current_user: DBUser = Depends(get_current_active_user),
|
current_user: models.User = Depends(deps.get_current_active_user),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get current user.
|
Get current user.
|
||||||
"""
|
"""
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
@router.post("/open", response_model=User)
|
@router.post("/open", response_model=schemas.User)
|
||||||
def create_user_open(
|
def create_user_open(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
password: str = Body(...),
|
password: str = Body(...),
|
||||||
email: EmailStr = Body(...),
|
email: EmailStr = Body(...),
|
||||||
full_name: str = Body(None),
|
full_name: str = Body(None),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Create new user without the need to be logged in.
|
Create new user without the need to be logged in.
|
||||||
"""
|
"""
|
||||||
@@ -111,17 +108,17 @@ def create_user_open(
|
|||||||
status_code=400,
|
status_code=400,
|
||||||
detail="The user with this username already exists in the system",
|
detail="The user with this username already exists in the system",
|
||||||
)
|
)
|
||||||
user_in = UserCreate(password=password, email=email, full_name=full_name)
|
user_in = schemas.UserCreate(password=password, email=email, full_name=full_name)
|
||||||
user = crud.user.create(db, obj_in=user_in)
|
user = crud.user.create(db, obj_in=user_in)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}", response_model=User)
|
@router.get("/{user_id}", response_model=schemas.User)
|
||||||
def read_user_by_id(
|
def read_user_by_id(
|
||||||
user_id: int,
|
user_id: int,
|
||||||
current_user: DBUser = Depends(get_current_active_user),
|
current_user: models.User = Depends(deps.get_current_active_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Get a specific user by id.
|
Get a specific user by id.
|
||||||
"""
|
"""
|
||||||
@@ -135,14 +132,14 @@ def read_user_by_id(
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{user_id}", response_model=User)
|
@router.put("/{user_id}", response_model=schemas.User)
|
||||||
def update_user(
|
def update_user(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
user_id: int,
|
user_id: int,
|
||||||
user_in: UserUpdate,
|
user_in: schemas.UserUpdate,
|
||||||
current_user: DBUser = Depends(get_current_active_superuser),
|
current_user: models.User = Depends(deps.get_current_active_superuser),
|
||||||
):
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Update a user.
|
Update a user.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,18 +1,21 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from pydantic.networks import EmailStr
|
from pydantic.networks import EmailStr
|
||||||
|
|
||||||
from app.api.utils.security import get_current_active_superuser
|
from app import models, schemas
|
||||||
|
from app.api import deps
|
||||||
from app.core.celery_app import celery_app
|
from app.core.celery_app import celery_app
|
||||||
from app.schemas.msg import Msg
|
|
||||||
from app.schemas.user import User # noqa: F401
|
|
||||||
from app.models.user import User as DBUser
|
|
||||||
from app.utils import send_test_email
|
from app.utils import send_test_email
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/test-celery/", response_model=Msg, status_code=201)
|
@router.post("/test-celery/", response_model=schemas.Msg, status_code=201)
|
||||||
def test_celery(msg: Msg, current_user: DBUser = Depends(get_current_active_superuser)):
|
def test_celery(
|
||||||
|
msg: schemas.Msg,
|
||||||
|
current_user: models.User = Depends(deps.get_current_active_superuser),
|
||||||
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Test Celery worker.
|
Test Celery worker.
|
||||||
"""
|
"""
|
||||||
@@ -20,10 +23,11 @@ def test_celery(msg: Msg, current_user: DBUser = Depends(get_current_active_supe
|
|||||||
return {"msg": "Word received"}
|
return {"msg": "Word received"}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/test-email/", response_model=Msg, status_code=201)
|
@router.post("/test-email/", response_model=schemas.Msg, status_code=201)
|
||||||
def test_email(
|
def test_email(
|
||||||
email_to: EmailStr, current_user: DBUser = Depends(get_current_active_superuser)
|
email_to: EmailStr,
|
||||||
):
|
current_user: models.User = Depends(deps.get_current_active_superuser),
|
||||||
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Test emails.
|
Test emails.
|
||||||
"""
|
"""
|
||||||
|
|||||||
61
{{cookiecutter.project_slug}}/backend/app/app/api/deps.py
Normal file
61
{{cookiecutter.project_slug}}/backend/app/app/api/deps.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from jose import jwt
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app import crud, models, schemas
|
||||||
|
from app.core import security
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.db.session import SessionLocal
|
||||||
|
|
||||||
|
reusable_oauth2 = OAuth2PasswordBearer(
|
||||||
|
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_db() -> Generator:
|
||||||
|
try:
|
||||||
|
db = SessionLocal()
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user(
|
||||||
|
db: Session = Depends(get_db), token: str = Depends(reusable_oauth2)
|
||||||
|
) -> models.User:
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
|
||||||
|
)
|
||||||
|
token_data = schemas.TokenPayload(**payload)
|
||||||
|
except (jwt.JWTError, ValidationError):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
)
|
||||||
|
user = crud.user.get(db, id=token_data.sub)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_active_user(
|
||||||
|
current_user: models.User = Depends(get_current_user),
|
||||||
|
) -> models.User:
|
||||||
|
if not crud.user.is_active(current_user):
|
||||||
|
raise HTTPException(status_code=400, detail="Inactive user")
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_active_superuser(
|
||||||
|
current_user: models.User = Depends(get_current_user),
|
||||||
|
) -> models.User:
|
||||||
|
if not crud.user.is_superuser(current_user):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="The user doesn't have enough privileges"
|
||||||
|
)
|
||||||
|
return current_user
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
from starlette.requests import Request
|
|
||||||
|
|
||||||
|
|
||||||
def get_db(request: Request):
|
|
||||||
return request.state.db
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
import jwt
|
|
||||||
from fastapi import Depends, HTTPException, Security
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from jwt import PyJWTError
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from starlette.status import HTTP_403_FORBIDDEN
|
|
||||||
|
|
||||||
from app import crud
|
|
||||||
from app.api.utils.db import get_db
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.jwt import ALGORITHM
|
|
||||||
from app.models.user import User
|
|
||||||
from app.schemas.token import TokenPayload
|
|
||||||
|
|
||||||
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(
|
|
||||||
db: Session = Depends(get_db), token: str = Security(reusable_oauth2)
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
|
||||||
token_data = TokenPayload(**payload)
|
|
||||||
except PyJWTError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
|
|
||||||
)
|
|
||||||
user = crud.user.get(db, id=token_data.user_id)
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_active_user(current_user: User = Security(get_current_user)):
|
|
||||||
if not crud.user.is_active(current_user):
|
|
||||||
raise HTTPException(status_code=400, detail="Inactive user")
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_active_superuser(current_user: User = Security(get_current_user)):
|
|
||||||
if not crud.user.is_superuser(current_user):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400, detail="The user doesn't have enough privileges"
|
|
||||||
)
|
|
||||||
return current_user
|
|
||||||
@@ -2,7 +2,7 @@ import logging
|
|||||||
|
|
||||||
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
||||||
|
|
||||||
from app.db.session import db_session
|
from app.db.session import SessionLocal
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -17,16 +17,17 @@ wait_seconds = 1
|
|||||||
before=before_log(logger, logging.INFO),
|
before=before_log(logger, logging.INFO),
|
||||||
after=after_log(logger, logging.WARN),
|
after=after_log(logger, logging.WARN),
|
||||||
)
|
)
|
||||||
def init():
|
def init() -> None:
|
||||||
try:
|
try:
|
||||||
|
db = SessionLocal()
|
||||||
# Try to create session to check if DB is awake
|
# Try to create session to check if DB is awake
|
||||||
db_session.execute("SELECT 1")
|
db.execute("SELECT 1")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
logger.info("Initializing service")
|
logger.info("Initializing service")
|
||||||
init()
|
init()
|
||||||
logger.info("Service finished initializing")
|
logger.info("Service finished initializing")
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import logging
|
|||||||
|
|
||||||
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
||||||
|
|
||||||
from app.db.session import db_session
|
from app.db.session import SessionLocal
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -17,16 +17,17 @@ wait_seconds = 1
|
|||||||
before=before_log(logger, logging.INFO),
|
before=before_log(logger, logging.INFO),
|
||||||
after=after_log(logger, logging.WARN),
|
after=after_log(logger, logging.WARN),
|
||||||
)
|
)
|
||||||
def init():
|
def init() -> None:
|
||||||
try:
|
try:
|
||||||
# Try to create session to check if DB is awake
|
# Try to create session to check if DB is awake
|
||||||
db_session.execute("SELECT 1")
|
db = SessionLocal()
|
||||||
|
db.execute("SELECT 1")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
logger.info("Initializing service")
|
logger.info("Initializing service")
|
||||||
init()
|
init()
|
||||||
logger.info("Service finished initializing")
|
logger.info("Service finished initializing")
|
||||||
|
|||||||
@@ -1,17 +1,14 @@
|
|||||||
import secrets
|
import secrets
|
||||||
from typing import List
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import AnyHttpUrl, BaseSettings, EmailStr, HttpUrl, PostgresDsn, validator
|
from pydantic import AnyHttpUrl, BaseSettings, EmailStr, HttpUrl, PostgresDsn, validator
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
|
|
||||||
API_V1_STR: str = "/api/v1"
|
API_V1_STR: str = "/api/v1"
|
||||||
|
|
||||||
SECRET_KEY: str = secrets.token_urlsafe(32)
|
SECRET_KEY: str = secrets.token_urlsafe(32)
|
||||||
|
# 60 minutes * 24 hours * 8 days = 8 days
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 60 minutes * 24 hours * 8 days = 8 days
|
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8
|
||||||
|
|
||||||
SERVER_NAME: str
|
SERVER_NAME: str
|
||||||
SERVER_HOST: AnyHttpUrl
|
SERVER_HOST: AnyHttpUrl
|
||||||
# BACKEND_CORS_ORIGINS is a JSON-formatted list of origins
|
# BACKEND_CORS_ORIGINS is a JSON-formatted list of origins
|
||||||
@@ -20,16 +17,18 @@ class Settings(BaseSettings):
|
|||||||
BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = []
|
BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = []
|
||||||
|
|
||||||
@validator("BACKEND_CORS_ORIGINS", pre=True)
|
@validator("BACKEND_CORS_ORIGINS", pre=True)
|
||||||
def assemble_cors_origins(cls, v):
|
def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]:
|
||||||
if isinstance(v, str) and not v.startswith("["):
|
if isinstance(v, str) and not v.startswith("["):
|
||||||
return [i.strip() for i in v.split(",")]
|
return [i.strip() for i in v.split(",")]
|
||||||
return v
|
elif isinstance(v, (list, str)):
|
||||||
|
return v
|
||||||
|
raise ValueError(v)
|
||||||
|
|
||||||
PROJECT_NAME: str
|
PROJECT_NAME: str
|
||||||
SENTRY_DSN: HttpUrl = None
|
SENTRY_DSN: Optional[HttpUrl] = None
|
||||||
|
|
||||||
@validator("SENTRY_DSN", pre=True)
|
@validator("SENTRY_DSN", pre=True)
|
||||||
def sentry_dsn_can_be_blank(cls, v):
|
def sentry_dsn_can_be_blank(cls, v: str) -> Optional[str]:
|
||||||
if len(v) == 0:
|
if len(v) == 0:
|
||||||
return None
|
return None
|
||||||
return v
|
return v
|
||||||
@@ -38,10 +37,10 @@ class Settings(BaseSettings):
|
|||||||
POSTGRES_USER: str
|
POSTGRES_USER: str
|
||||||
POSTGRES_PASSWORD: str
|
POSTGRES_PASSWORD: str
|
||||||
POSTGRES_DB: str
|
POSTGRES_DB: str
|
||||||
SQLALCHEMY_DATABASE_URI: PostgresDsn = None
|
SQLALCHEMY_DATABASE_URI: Optional[PostgresDsn] = None
|
||||||
|
|
||||||
@validator("SQLALCHEMY_DATABASE_URI", pre=True)
|
@validator("SQLALCHEMY_DATABASE_URI", pre=True)
|
||||||
def assemble_db_connection(cls, v, values):
|
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
|
||||||
if isinstance(v, str):
|
if isinstance(v, str):
|
||||||
return v
|
return v
|
||||||
return PostgresDsn.build(
|
return PostgresDsn.build(
|
||||||
@@ -53,15 +52,15 @@ class Settings(BaseSettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
SMTP_TLS: bool = True
|
SMTP_TLS: bool = True
|
||||||
SMTP_PORT: int = None
|
SMTP_PORT: Optional[int] = None
|
||||||
SMTP_HOST: str = None
|
SMTP_HOST: Optional[str] = None
|
||||||
SMTP_USER: str = None
|
SMTP_USER: Optional[str] = None
|
||||||
SMTP_PASSWORD: str = None
|
SMTP_PASSWORD: Optional[str] = None
|
||||||
EMAILS_FROM_EMAIL: EmailStr = None
|
EMAILS_FROM_EMAIL: Optional[EmailStr] = None
|
||||||
EMAILS_FROM_NAME: str = None
|
EMAILS_FROM_NAME: Optional[str] = None
|
||||||
|
|
||||||
@validator("EMAILS_FROM_NAME")
|
@validator("EMAILS_FROM_NAME")
|
||||||
def get_project_name(cls, v, values):
|
def get_project_name(cls, v: Optional[str], values: Dict[str, Any]) -> str:
|
||||||
if not v:
|
if not v:
|
||||||
return values["PROJECT_NAME"]
|
return values["PROJECT_NAME"]
|
||||||
return v
|
return v
|
||||||
@@ -71,18 +70,16 @@ class Settings(BaseSettings):
|
|||||||
EMAILS_ENABLED: bool = False
|
EMAILS_ENABLED: bool = False
|
||||||
|
|
||||||
@validator("EMAILS_ENABLED", pre=True)
|
@validator("EMAILS_ENABLED", pre=True)
|
||||||
def get_emails_enabled(cls, v, values):
|
def get_emails_enabled(cls, v: bool, values: Dict[str, Any]) -> bool:
|
||||||
return bool(
|
return bool(
|
||||||
values.get("SMTP_HOST")
|
values.get("SMTP_HOST")
|
||||||
and values.get("SMTP_PORT")
|
and values.get("SMTP_PORT")
|
||||||
and values.get("EMAILS_FROM_EMAIL")
|
and values.get("EMAILS_FROM_EMAIL")
|
||||||
)
|
)
|
||||||
|
|
||||||
EMAIL_TEST_USER: EmailStr = "test@example.com"
|
EMAIL_TEST_USER: EmailStr = "test@example.com" # type: ignore
|
||||||
|
|
||||||
FIRST_SUPERUSER: EmailStr
|
FIRST_SUPERUSER: EmailStr
|
||||||
FIRST_SUPERUSER_PASSWORD: str
|
FIRST_SUPERUSER_PASSWORD: str
|
||||||
|
|
||||||
USERS_OPEN_REGISTRATION: bool = False
|
USERS_OPEN_REGISTRATION: bool = False
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|||||||
@@ -1,19 +0,0 @@
|
|||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
import jwt
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
ALGORITHM = "HS256"
|
|
||||||
access_token_jwt_subject = "access"
|
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(*, data: dict, expires_delta: timedelta = None):
|
|
||||||
to_encode = data.copy()
|
|
||||||
if expires_delta:
|
|
||||||
expire = datetime.utcnow() + expires_delta
|
|
||||||
else:
|
|
||||||
expire = datetime.utcnow() + timedelta(minutes=15)
|
|
||||||
to_encode.update({"exp": expire, "sub": access_token_jwt_subject})
|
|
||||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
|
||||||
return encoded_jwt
|
|
||||||
@@ -1,11 +1,34 @@
|
|||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from jose import jwt
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str):
|
ALGORITHM = "HS256"
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(
|
||||||
|
subject: Union[str, Any], expires_delta: timedelta = None
|
||||||
|
) -> str:
|
||||||
|
if expires_delta:
|
||||||
|
expire = datetime.utcnow() + expires_delta
|
||||||
|
else:
|
||||||
|
expire = datetime.utcnow() + timedelta(
|
||||||
|
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||||
|
)
|
||||||
|
to_encode = {"exp": expire, "sub": str(subject)}
|
||||||
|
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
return encoded_jwt
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
return pwd_context.verify(plain_password, hashed_password)
|
return pwd_context.verify(plain_password, hashed_password)
|
||||||
|
|
||||||
|
|
||||||
def get_password_hash(password: str):
|
def get_password_hash(password: str) -> str:
|
||||||
return pwd_context.hash(password)
|
return pwd_context.hash(password)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from .crud_user import user # noqa: F401
|
|
||||||
from .crud_item import item # noqa: F401
|
from .crud_item import item # noqa: F401
|
||||||
|
from .crud_user import user # noqa: F401
|
||||||
|
|
||||||
# For a new basic set of CRUD operations you could just do
|
# For a new basic set of CRUD operations you could just do
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional, Generic, TypeVar, Type
|
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -23,35 +23,44 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def get(self, db_session: Session, id: int) -> Optional[ModelType]:
|
def get(self, db: Session, id: Any) -> Optional[ModelType]:
|
||||||
return db_session.query(self.model).filter(self.model.id == id).first()
|
return db.query(self.model).filter(self.model.id == id).first()
|
||||||
|
|
||||||
def get_multi(self, db_session: Session, *, skip=0, limit=100) -> List[ModelType]:
|
def get_multi(
|
||||||
return db_session.query(self.model).offset(skip).limit(limit).all()
|
self, db: Session, *, skip: int = 0, limit: int = 100
|
||||||
|
) -> List[ModelType]:
|
||||||
|
return db.query(self.model).offset(skip).limit(limit).all()
|
||||||
|
|
||||||
def create(self, db_session: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||||
obj_in_data = jsonable_encoder(obj_in)
|
obj_in_data = jsonable_encoder(obj_in)
|
||||||
db_obj = self.model(**obj_in_data)
|
db_obj = self.model(**obj_in_data) # type: ignore
|
||||||
db_session.add(db_obj)
|
db.add(db_obj)
|
||||||
db_session.commit()
|
db.commit()
|
||||||
db_session.refresh(db_obj)
|
db.refresh(db_obj)
|
||||||
return db_obj
|
return db_obj
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self, db_session: Session, *, db_obj: ModelType, obj_in: UpdateSchemaType
|
self,
|
||||||
|
db: Session,
|
||||||
|
*,
|
||||||
|
db_obj: ModelType,
|
||||||
|
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||||
) -> ModelType:
|
) -> ModelType:
|
||||||
obj_data = jsonable_encoder(db_obj)
|
obj_data = jsonable_encoder(db_obj)
|
||||||
update_data = obj_in.dict(exclude_unset=True)
|
if isinstance(obj_in, dict):
|
||||||
|
update_data = obj_in
|
||||||
|
else:
|
||||||
|
update_data = obj_in.dict(exclude_unset=True)
|
||||||
for field in obj_data:
|
for field in obj_data:
|
||||||
if field in update_data:
|
if field in update_data:
|
||||||
setattr(db_obj, field, update_data[field])
|
setattr(db_obj, field, update_data[field])
|
||||||
db_session.add(db_obj)
|
db.add(db_obj)
|
||||||
db_session.commit()
|
db.commit()
|
||||||
db_session.refresh(db_obj)
|
db.refresh(db_obj)
|
||||||
return db_obj
|
return db_obj
|
||||||
|
|
||||||
def remove(self, db_session: Session, *, id: int) -> ModelType:
|
def remove(self, db: Session, *, id: int) -> ModelType:
|
||||||
obj = db_session.query(self.model).get(id)
|
obj = db.query(self.model).get(id)
|
||||||
db_session.delete(obj)
|
db.delete(obj)
|
||||||
db_session.commit()
|
db.commit()
|
||||||
return obj
|
return obj
|
||||||
|
|||||||
@@ -3,27 +3,27 @@ from typing import List
|
|||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.crud.base import CRUDBase
|
||||||
from app.models.item import Item
|
from app.models.item import Item
|
||||||
from app.schemas.item import ItemCreate, ItemUpdate
|
from app.schemas.item import ItemCreate, ItemUpdate
|
||||||
from app.crud.base import CRUDBase
|
|
||||||
|
|
||||||
|
|
||||||
class CRUDItem(CRUDBase[Item, ItemCreate, ItemUpdate]):
|
class CRUDItem(CRUDBase[Item, ItemCreate, ItemUpdate]):
|
||||||
def create_with_owner(
|
def create_with_owner(
|
||||||
self, db_session: Session, *, obj_in: ItemCreate, owner_id: int
|
self, db: Session, *, obj_in: ItemCreate, owner_id: int
|
||||||
) -> Item:
|
) -> Item:
|
||||||
obj_in_data = jsonable_encoder(obj_in)
|
obj_in_data = jsonable_encoder(obj_in)
|
||||||
db_obj = self.model(**obj_in_data, owner_id=owner_id)
|
db_obj = self.model(**obj_in_data, owner_id=owner_id)
|
||||||
db_session.add(db_obj)
|
db.add(db_obj)
|
||||||
db_session.commit()
|
db.commit()
|
||||||
db_session.refresh(db_obj)
|
db.refresh(db_obj)
|
||||||
return db_obj
|
return db_obj
|
||||||
|
|
||||||
def get_multi_by_owner(
|
def get_multi_by_owner(
|
||||||
self, db_session: Session, *, owner_id: int, skip=0, limit=100
|
self, db: Session, *, owner_id: int, skip: int = 0, limit: int = 100
|
||||||
) -> List[Item]:
|
) -> List[Item]:
|
||||||
return (
|
return (
|
||||||
db_session.query(self.model)
|
db.query(self.model)
|
||||||
.filter(Item.owner_id == owner_id)
|
.filter(Item.owner_id == owner_id)
|
||||||
.offset(skip)
|
.offset(skip)
|
||||||
.limit(limit)
|
.limit(limit)
|
||||||
|
|||||||
@@ -1,42 +1,44 @@
|
|||||||
from typing import Optional
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.models.user import User
|
from app.core.security import get_password_hash, verify_password
|
||||||
from app.schemas.user import UserCreate, UserUpdate, UserInDB
|
|
||||||
from app.core.security import verify_password, get_password_hash
|
|
||||||
from app.crud.base import CRUDBase
|
from app.crud.base import CRUDBase
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.user import UserCreate, UserUpdate
|
||||||
|
|
||||||
|
|
||||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||||
def get_by_email(self, db_session: Session, *, email: str) -> Optional[User]:
|
def get_by_email(self, db: Session, *, email: str) -> Optional[User]:
|
||||||
return db_session.query(User).filter(User.email == email).first()
|
return db.query(User).filter(User.email == email).first()
|
||||||
|
|
||||||
def create(self, db_session: Session, *, obj_in: UserCreate) -> User:
|
def create(self, db: Session, *, obj_in: UserCreate) -> User:
|
||||||
db_obj = User(
|
db_obj = User(
|
||||||
email=obj_in.email,
|
email=obj_in.email,
|
||||||
hashed_password=get_password_hash(obj_in.password),
|
hashed_password=get_password_hash(obj_in.password),
|
||||||
full_name=obj_in.full_name,
|
full_name=obj_in.full_name,
|
||||||
is_superuser=obj_in.is_superuser,
|
is_superuser=obj_in.is_superuser,
|
||||||
)
|
)
|
||||||
db_session.add(db_obj)
|
db.add(db_obj)
|
||||||
db_session.commit()
|
db.commit()
|
||||||
db_session.refresh(db_obj)
|
db.refresh(db_obj)
|
||||||
return db_obj
|
return db_obj
|
||||||
|
|
||||||
def update(self, db_session: Session, *, db_obj: User, obj_in: UserUpdate) -> User:
|
def update(
|
||||||
if obj_in.password:
|
self, db: Session, *, db_obj: User, obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||||
|
) -> User:
|
||||||
|
if isinstance(obj_in, dict):
|
||||||
|
update_data = obj_in
|
||||||
|
else:
|
||||||
update_data = obj_in.dict(exclude_unset=True)
|
update_data = obj_in.dict(exclude_unset=True)
|
||||||
hashed_password = get_password_hash(obj_in.password)
|
if update_data["password"]:
|
||||||
|
hashed_password = get_password_hash(update_data["password"])
|
||||||
del update_data["password"]
|
del update_data["password"]
|
||||||
update_data["hashed_password"] = hashed_password
|
update_data["hashed_password"] = hashed_password
|
||||||
use_obj_in = UserInDB.parse_obj(update_data)
|
return super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||||
return super().update(db_session, db_obj=db_obj, obj_in=use_obj_in)
|
|
||||||
|
|
||||||
def authenticate(
|
def authenticate(self, db: Session, *, email: str, password: str) -> Optional[User]:
|
||||||
self, db_session: Session, *, email: str, password: str
|
user = self.get_by_email(db, email=email)
|
||||||
) -> Optional[User]:
|
|
||||||
user = self.get_by_email(db_session, email=email)
|
|
||||||
if not user:
|
if not user:
|
||||||
return None
|
return None
|
||||||
if not verify_password(password, user.hashed_password):
|
if not verify_password(password, user.hashed_password):
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# Import all the models, so that Base has them before being
|
# Import all the models, so that Base has them before being
|
||||||
# imported by Alembic
|
# imported by Alembic
|
||||||
from app.db.base_class import Base # noqa
|
from app.db.base_class import Base # noqa
|
||||||
from app.models.user import User # noqa
|
|
||||||
from app.models.item import Item # noqa
|
from app.models.item import Item # noqa
|
||||||
|
from app.models.user import User # noqa
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy.ext.declarative import as_declarative, declared_attr
|
from sqlalchemy.ext.declarative import as_declarative, declared_attr
|
||||||
|
|
||||||
|
|
||||||
@as_declarative()
|
@as_declarative()
|
||||||
class Base:
|
class Base:
|
||||||
|
id: Any
|
||||||
|
__name__: str
|
||||||
# Generate __tablename__ automatically
|
# Generate __tablename__ automatically
|
||||||
@declared_attr
|
@declared_attr
|
||||||
def __tablename__(cls):
|
def __tablename__(cls) -> str:
|
||||||
return cls.__name__.lower()
|
return cls.__name__.lower()
|
||||||
|
|||||||
@@ -1,24 +1,25 @@
|
|||||||
from app import crud
|
from sqlalchemy.orm import Session
|
||||||
from app.core.config import settings
|
|
||||||
from app.schemas.user import UserCreate
|
|
||||||
|
|
||||||
# make sure all SQL Alchemy models are imported before initializing DB
|
from app import crud, schemas
|
||||||
# otherwise, SQL Alchemy might fail to initialize relationships properly
|
from app.core.config import settings
|
||||||
# for more details: https://github.com/tiangolo/full-stack-fastapi-postgresql/issues/28
|
|
||||||
from app.db import base # noqa: F401
|
from app.db import base # noqa: F401
|
||||||
|
|
||||||
|
# make sure all SQL Alchemy models are imported (app.db.base) before initializing DB
|
||||||
|
# otherwise, SQL Alchemy might fail to initialize relationships properly
|
||||||
|
# for more details: https://github.com/tiangolo/full-stack-fastapi-postgresql/issues/28
|
||||||
|
|
||||||
def init_db(db_session):
|
|
||||||
|
def init_db(db: Session) -> None:
|
||||||
# Tables should be created with Alembic migrations
|
# Tables should be created with Alembic migrations
|
||||||
# But if you don't want to use migrations, create
|
# But if you don't want to use migrations, create
|
||||||
# the tables un-commenting the next line
|
# the tables un-commenting the next line
|
||||||
# Base.metadata.create_all(bind=engine)
|
# Base.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
user = crud.user.get_by_email(db_session, email=settings.FIRST_SUPERUSER)
|
user = crud.user.get_by_email(db, email=settings.FIRST_SUPERUSER)
|
||||||
if not user:
|
if not user:
|
||||||
user_in = UserCreate(
|
user_in = schemas.UserCreate(
|
||||||
email=settings.FIRST_SUPERUSER,
|
email=settings.FIRST_SUPERUSER,
|
||||||
password=settings.FIRST_SUPERUSER_PASSWORD,
|
password=settings.FIRST_SUPERUSER_PASSWORD,
|
||||||
is_superuser=True,
|
is_superuser=True,
|
||||||
)
|
)
|
||||||
user = crud.user.create(db_session, obj_in=user_in) # noqa: F841
|
user = crud.user.create(db, obj_in=user_in) # noqa: F841
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, pool_pre_ping=True)
|
engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, pool_pre_ping=True)
|
||||||
db_session = scoped_session(
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
||||||
)
|
|
||||||
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from app.db.init_db import init_db
|
from app.db.init_db import init_db
|
||||||
from app.db.session import db_session
|
from app.db.session import SessionLocal
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def init():
|
def init() -> None:
|
||||||
init_db(db_session)
|
db = SessionLocal()
|
||||||
|
init_db(db)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
logger.info("Creating initial data")
|
logger.info("Creating initial data")
|
||||||
init()
|
init()
|
||||||
logger.info("Initial data created")
|
logger.info("Initial data created")
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from starlette.middleware.cors import CORSMiddleware
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
from starlette.requests import Request
|
|
||||||
|
|
||||||
from app.api.api_v1.api import api_router
|
from app.api.api_v1.api import api_router
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.db.session import Session
|
|
||||||
|
|
||||||
app = FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")
|
app = FastAPI(
|
||||||
|
title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json"
|
||||||
|
)
|
||||||
|
|
||||||
# Set all CORS enabled origins
|
# Set all CORS enabled origins
|
||||||
if settings.BACKEND_CORS_ORIGINS:
|
if settings.BACKEND_CORS_ORIGINS:
|
||||||
@@ -16,14 +16,6 @@ if settings.BACKEND_CORS_ORIGINS:
|
|||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
),
|
)
|
||||||
|
|
||||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
|
||||||
async def db_session_middleware(request: Request, call_next):
|
|
||||||
request.state.db = Session()
|
|
||||||
response = await call_next(request)
|
|
||||||
request.state.db.close()
|
|
||||||
return response
|
|
||||||
|
|||||||
@@ -0,0 +1,2 @@
|
|||||||
|
from .item import Item
|
||||||
|
from .user import User
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import Column, ForeignKey, Integer, String
|
from sqlalchemy import Column, ForeignKey, Integer, String
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
from app.db.base_class import Base
|
from app.db.base_class import Base
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .user import User # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
class Item(Base):
|
class Item(Base):
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
|||||||
@@ -1,14 +1,19 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import Boolean, Column, Integer, String
|
from sqlalchemy import Boolean, Column, Integer, String
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
from app.db.base_class import Base
|
from app.db.base_class import Base
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .item import Item # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
full_name = Column(String, index=True)
|
full_name = Column(String, index=True)
|
||||||
email = Column(String, unique=True, index=True)
|
email = Column(String, unique=True, index=True, nullable=False)
|
||||||
hashed_password = Column(String)
|
hashed_password = Column(String, nullable=False)
|
||||||
is_active = Column(Boolean(), default=True)
|
is_active = Column(Boolean(), default=True)
|
||||||
is_superuser = Column(Boolean(), default=False)
|
is_superuser = Column(Boolean(), default=False)
|
||||||
items = relationship("Item", back_populates="owner")
|
items = relationship("Item", back_populates="owner")
|
||||||
|
|||||||
@@ -0,0 +1,4 @@
|
|||||||
|
from .item import Item, ItemCreate, ItemInDB, ItemUpdate
|
||||||
|
from .msg import Msg
|
||||||
|
from .token import Token, TokenPayload
|
||||||
|
from .user import User, UserCreate, UserInDB, UserUpdate
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from .user import User # noqa: F401
|
from .user import User # noqa: F401
|
||||||
@@ -5,8 +7,8 @@ from .user import User # noqa: F401
|
|||||||
|
|
||||||
# Shared properties
|
# Shared properties
|
||||||
class ItemBase(BaseModel):
|
class ItemBase(BaseModel):
|
||||||
title: str = None
|
title: Optional[str] = None
|
||||||
description: str = None
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
# Properties to receive on item creation
|
# Properties to receive on item creation
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
@@ -7,4 +9,4 @@ class Token(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class TokenPayload(BaseModel):
|
class TokenPayload(BaseModel):
|
||||||
user_id: int = None
|
sub: Optional[int] = None
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from pydantic import BaseModel, EmailStr
|
|||||||
class UserBase(BaseModel):
|
class UserBase(BaseModel):
|
||||||
email: Optional[EmailStr] = None
|
email: Optional[EmailStr] = None
|
||||||
is_active: Optional[bool] = True
|
is_active: Optional[bool] = True
|
||||||
is_superuser: Optional[bool] = False
|
is_superuser: bool = False
|
||||||
full_name: Optional[str] = None
|
full_name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ class UserUpdate(UserBase):
|
|||||||
|
|
||||||
|
|
||||||
class UserInDBBase(UserBase):
|
class UserInDBBase(UserBase):
|
||||||
id: int = None
|
id: Optional[int] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.tests.utils.utils import get_server_api
|
from app.tests.utils.utils import get_server_api
|
||||||
|
|
||||||
|
|
||||||
def test_celery_worker_test(superuser_token_headers):
|
def test_celery_worker_test(superuser_token_headers: Dict[str, str]) -> None:
|
||||||
server_api = get_server_api()
|
server_api = get_server_api()
|
||||||
data = {"msg": "test"}
|
data = {"msg": "test"}
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
import requests
|
import requests
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.tests.utils.item import create_random_item
|
from app.tests.utils.item import create_random_item
|
||||||
from app.tests.utils.utils import get_server_api
|
|
||||||
from app.tests.utils.user import create_random_user # noqa: F401
|
from app.tests.utils.user import create_random_user # noqa: F401
|
||||||
|
from app.tests.utils.utils import get_server_api
|
||||||
|
|
||||||
|
|
||||||
def test_create_item(superuser_token_headers):
|
def test_create_item(superuser_token_headers: dict, db: Session) -> None:
|
||||||
server_api = get_server_api()
|
server_api = get_server_api()
|
||||||
data = {"title": "Foo", "description": "Fighters"}
|
data = {"title": "Foo", "description": "Fighters"}
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
@@ -22,8 +23,8 @@ def test_create_item(superuser_token_headers):
|
|||||||
assert "owner_id" in content
|
assert "owner_id" in content
|
||||||
|
|
||||||
|
|
||||||
def test_read_item(superuser_token_headers):
|
def test_read_item(superuser_token_headers: dict, db: Session) -> None:
|
||||||
item = create_random_item()
|
item = create_random_item(db)
|
||||||
server_api = get_server_api()
|
server_api = get_server_api()
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{server_api}{settings.API_V1_STR}/items/{item.id}",
|
f"{server_api}{settings.API_V1_STR}/items/{item.id}",
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.tests.utils.utils import get_server_api
|
from app.tests.utils.utils import get_server_api
|
||||||
|
|
||||||
|
|
||||||
def test_get_access_token():
|
def test_get_access_token() -> None:
|
||||||
server_api = get_server_api()
|
server_api = get_server_api()
|
||||||
login_data = {
|
login_data = {
|
||||||
"username": settings.FIRST_SUPERUSER,
|
"username": settings.FIRST_SUPERUSER,
|
||||||
@@ -19,7 +21,7 @@ def test_get_access_token():
|
|||||||
assert tokens["access_token"]
|
assert tokens["access_token"]
|
||||||
|
|
||||||
|
|
||||||
def test_use_access_token(superuser_token_headers):
|
def test_use_access_token(superuser_token_headers: Dict[str, str]) -> None:
|
||||||
server_api = get_server_api()
|
server_api = get_server_api()
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
f"{server_api}{settings.API_V1_STR}/login/test-token",
|
f"{server_api}{settings.API_V1_STR}/login/test-token",
|
||||||
|
|||||||
@@ -1,13 +1,15 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app import crud
|
from app import crud
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.db.session import db_session
|
|
||||||
from app.schemas.user import UserCreate
|
from app.schemas.user import UserCreate
|
||||||
from app.tests.utils.utils import get_server_api, random_lower_string, random_email
|
from app.tests.utils.utils import get_server_api, random_email, random_lower_string
|
||||||
|
|
||||||
|
|
||||||
def test_get_users_superuser_me(superuser_token_headers):
|
def test_get_users_superuser_me(superuser_token_headers: Dict[str, str]) -> None:
|
||||||
server_api = get_server_api()
|
server_api = get_server_api()
|
||||||
r = requests.get(
|
r = requests.get(
|
||||||
f"{server_api}{settings.API_V1_STR}/users/me", headers=superuser_token_headers
|
f"{server_api}{settings.API_V1_STR}/users/me", headers=superuser_token_headers
|
||||||
@@ -19,7 +21,7 @@ def test_get_users_superuser_me(superuser_token_headers):
|
|||||||
assert current_user["email"] == settings.FIRST_SUPERUSER
|
assert current_user["email"] == settings.FIRST_SUPERUSER
|
||||||
|
|
||||||
|
|
||||||
def test_get_users_normal_user_me(normal_user_token_headers):
|
def test_get_users_normal_user_me(normal_user_token_headers: Dict[str, str]) -> None:
|
||||||
server_api = get_server_api()
|
server_api = get_server_api()
|
||||||
r = requests.get(
|
r = requests.get(
|
||||||
f"{server_api}{settings.API_V1_STR}/users/me", headers=normal_user_token_headers
|
f"{server_api}{settings.API_V1_STR}/users/me", headers=normal_user_token_headers
|
||||||
@@ -31,7 +33,7 @@ def test_get_users_normal_user_me(normal_user_token_headers):
|
|||||||
assert current_user["email"] == settings.EMAIL_TEST_USER
|
assert current_user["email"] == settings.EMAIL_TEST_USER
|
||||||
|
|
||||||
|
|
||||||
def test_create_user_new_email(superuser_token_headers):
|
def test_create_user_new_email(superuser_token_headers: dict, db: Session) -> None:
|
||||||
server_api = get_server_api()
|
server_api = get_server_api()
|
||||||
username = random_email()
|
username = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
@@ -43,16 +45,17 @@ def test_create_user_new_email(superuser_token_headers):
|
|||||||
)
|
)
|
||||||
assert 200 <= r.status_code < 300
|
assert 200 <= r.status_code < 300
|
||||||
created_user = r.json()
|
created_user = r.json()
|
||||||
user = crud.user.get_by_email(db_session, email=username)
|
user = crud.user.get_by_email(db, email=username)
|
||||||
|
assert user
|
||||||
assert user.email == created_user["email"]
|
assert user.email == created_user["email"]
|
||||||
|
|
||||||
|
|
||||||
def test_get_existing_user(superuser_token_headers):
|
def test_get_existing_user(superuser_token_headers: dict, db: Session) -> None:
|
||||||
server_api = get_server_api()
|
server_api = get_server_api()
|
||||||
username = random_email()
|
username = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
user_in = UserCreate(email=username, password=password)
|
user_in = UserCreate(email=username, password=password)
|
||||||
user = crud.user.create(db_session, obj_in=user_in)
|
user = crud.user.create(db, obj_in=user_in)
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
r = requests.get(
|
r = requests.get(
|
||||||
f"{server_api}{settings.API_V1_STR}/users/{user_id}",
|
f"{server_api}{settings.API_V1_STR}/users/{user_id}",
|
||||||
@@ -60,17 +63,20 @@ def test_get_existing_user(superuser_token_headers):
|
|||||||
)
|
)
|
||||||
assert 200 <= r.status_code < 300
|
assert 200 <= r.status_code < 300
|
||||||
api_user = r.json()
|
api_user = r.json()
|
||||||
user = crud.user.get_by_email(db_session, email=username)
|
existing_user = crud.user.get_by_email(db, email=username)
|
||||||
assert user.email == api_user["email"]
|
assert existing_user
|
||||||
|
assert existing_user.email == api_user["email"]
|
||||||
|
|
||||||
|
|
||||||
def test_create_user_existing_username(superuser_token_headers):
|
def test_create_user_existing_username(
|
||||||
|
superuser_token_headers: dict, db: Session
|
||||||
|
) -> None:
|
||||||
server_api = get_server_api()
|
server_api = get_server_api()
|
||||||
username = random_email()
|
username = random_email()
|
||||||
# username = email
|
# username = email
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
user_in = UserCreate(email=username, password=password)
|
user_in = UserCreate(email=username, password=password)
|
||||||
crud.user.create(db_session, obj_in=user_in)
|
crud.user.create(db, obj_in=user_in)
|
||||||
data = {"email": username, "password": password}
|
data = {"email": username, "password": password}
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
f"{server_api}{settings.API_V1_STR}/users/",
|
f"{server_api}{settings.API_V1_STR}/users/",
|
||||||
@@ -82,7 +88,7 @@ def test_create_user_existing_username(superuser_token_headers):
|
|||||||
assert "_id" not in created_user
|
assert "_id" not in created_user
|
||||||
|
|
||||||
|
|
||||||
def test_create_user_by_normal_user(normal_user_token_headers):
|
def test_create_user_by_normal_user(normal_user_token_headers: Dict[str, str]) -> None:
|
||||||
server_api = get_server_api()
|
server_api = get_server_api()
|
||||||
username = random_email()
|
username = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
@@ -95,17 +101,17 @@ def test_create_user_by_normal_user(normal_user_token_headers):
|
|||||||
assert r.status_code == 400
|
assert r.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_users(superuser_token_headers):
|
def test_retrieve_users(superuser_token_headers: dict, db: Session) -> None:
|
||||||
server_api = get_server_api()
|
server_api = get_server_api()
|
||||||
username = random_email()
|
username = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
user_in = UserCreate(email=username, password=password)
|
user_in = UserCreate(email=username, password=password)
|
||||||
user = crud.user.create(db_session, obj_in=user_in)
|
crud.user.create(db, obj_in=user_in)
|
||||||
|
|
||||||
username2 = random_email()
|
username2 = random_email()
|
||||||
password2 = random_lower_string()
|
password2 = random_lower_string()
|
||||||
user_in2 = UserCreate(email=username2, password=password2)
|
user_in2 = UserCreate(email=username2, password=password2)
|
||||||
crud.user.create(db_session, obj_in=user_in2)
|
crud.user.create(db, obj_in=user_in2)
|
||||||
|
|
||||||
r = requests.get(
|
r = requests.get(
|
||||||
f"{server_api}{settings.API_V1_STR}/users/", headers=superuser_token_headers
|
f"{server_api}{settings.API_V1_STR}/users/", headers=superuser_token_headers
|
||||||
@@ -113,5 +119,5 @@ def test_retrieve_users(superuser_token_headers):
|
|||||||
all_users = r.json()
|
all_users = r.json()
|
||||||
|
|
||||||
assert len(all_users) > 1
|
assert len(all_users) > 1
|
||||||
for user in all_users:
|
for item in all_users:
|
||||||
assert "email" in user
|
assert "email" in item
|
||||||
|
|||||||
@@ -1,20 +1,29 @@
|
|||||||
|
from typing import Dict, Iterator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.tests.utils.utils import get_server_api, get_superuser_token_headers
|
from app.db.session import SessionLocal
|
||||||
from app.tests.utils.user import authentication_token_from_email
|
from app.tests.utils.user import authentication_token_from_email
|
||||||
|
from app.tests.utils.utils import get_server_api, get_superuser_token_headers
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def db() -> Iterator[Session]:
|
||||||
|
yield SessionLocal()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def server_api():
|
def server_api() -> str:
|
||||||
return get_server_api()
|
return get_server_api()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def superuser_token_headers():
|
def superuser_token_headers() -> Dict[str, str]:
|
||||||
return get_superuser_token_headers()
|
return get_superuser_token_headers()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def normal_user_token_headers():
|
def normal_user_token_headers(db: Session) -> Dict[str, str]:
|
||||||
return authentication_token_from_email(settings.EMAIL_TEST_USER)
|
return authentication_token_from_email(email=settings.EMAIL_TEST_USER, db=db)
|
||||||
|
|||||||
@@ -1,65 +1,59 @@
|
|||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app import crud
|
from app import crud
|
||||||
from app.schemas.item import ItemCreate, ItemUpdate
|
from app.schemas.item import ItemCreate, ItemUpdate
|
||||||
from app.tests.utils.user import create_random_user
|
from app.tests.utils.user import create_random_user
|
||||||
from app.tests.utils.utils import random_lower_string
|
from app.tests.utils.utils import random_lower_string
|
||||||
from app.db.session import db_session
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_item():
|
def test_create_item(db: Session) -> None:
|
||||||
title = random_lower_string()
|
title = random_lower_string()
|
||||||
description = random_lower_string()
|
description = random_lower_string()
|
||||||
item_in = ItemCreate(title=title, description=description)
|
item_in = ItemCreate(title=title, description=description)
|
||||||
user = create_random_user()
|
user = create_random_user(db)
|
||||||
item = crud.item.create_with_owner(
|
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id)
|
||||||
db_session=db_session, obj_in=item_in, owner_id=user.id
|
|
||||||
)
|
|
||||||
assert item.title == title
|
assert item.title == title
|
||||||
assert item.description == description
|
assert item.description == description
|
||||||
assert item.owner_id == user.id
|
assert item.owner_id == user.id
|
||||||
|
|
||||||
|
|
||||||
def test_get_item():
|
def test_get_item(db: Session) -> None:
|
||||||
title = random_lower_string()
|
title = random_lower_string()
|
||||||
description = random_lower_string()
|
description = random_lower_string()
|
||||||
item_in = ItemCreate(title=title, description=description)
|
item_in = ItemCreate(title=title, description=description)
|
||||||
user = create_random_user()
|
user = create_random_user(db)
|
||||||
item = crud.item.create_with_owner(
|
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id)
|
||||||
db_session=db_session, obj_in=item_in, owner_id=user.id
|
stored_item = crud.item.get(db=db, id=item.id)
|
||||||
)
|
assert stored_item
|
||||||
stored_item = crud.item.get(db_session=db_session, id=item.id)
|
|
||||||
assert item.id == stored_item.id
|
assert item.id == stored_item.id
|
||||||
assert item.title == stored_item.title
|
assert item.title == stored_item.title
|
||||||
assert item.description == stored_item.description
|
assert item.description == stored_item.description
|
||||||
assert item.owner_id == stored_item.owner_id
|
assert item.owner_id == stored_item.owner_id
|
||||||
|
|
||||||
|
|
||||||
def test_update_item():
|
def test_update_item(db: Session) -> None:
|
||||||
title = random_lower_string()
|
title = random_lower_string()
|
||||||
description = random_lower_string()
|
description = random_lower_string()
|
||||||
item_in = ItemCreate(title=title, description=description)
|
item_in = ItemCreate(title=title, description=description)
|
||||||
user = create_random_user()
|
user = create_random_user(db)
|
||||||
item = crud.item.create_with_owner(
|
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id)
|
||||||
db_session=db_session, obj_in=item_in, owner_id=user.id
|
|
||||||
)
|
|
||||||
description2 = random_lower_string()
|
description2 = random_lower_string()
|
||||||
item_update = ItemUpdate(description=description2)
|
item_update = ItemUpdate(description=description2)
|
||||||
item2 = crud.item.update(db_session=db_session, db_obj=item, obj_in=item_update)
|
item2 = crud.item.update(db=db, db_obj=item, obj_in=item_update)
|
||||||
assert item.id == item2.id
|
assert item.id == item2.id
|
||||||
assert item.title == item2.title
|
assert item.title == item2.title
|
||||||
assert item2.description == description2
|
assert item2.description == description2
|
||||||
assert item.owner_id == item2.owner_id
|
assert item.owner_id == item2.owner_id
|
||||||
|
|
||||||
|
|
||||||
def test_delete_item():
|
def test_delete_item(db: Session) -> None:
|
||||||
title = random_lower_string()
|
title = random_lower_string()
|
||||||
description = random_lower_string()
|
description = random_lower_string()
|
||||||
item_in = ItemCreate(title=title, description=description)
|
item_in = ItemCreate(title=title, description=description)
|
||||||
user = create_random_user()
|
user = create_random_user(db)
|
||||||
item = crud.item.create_with_owner(
|
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id)
|
||||||
db_session=db_session, obj_in=item_in, owner_id=user.id
|
item2 = crud.item.remove(db=db, id=item.id)
|
||||||
)
|
item3 = crud.item.get(db=db, id=item.id)
|
||||||
item2 = crud.item.remove(db_session=db_session, id=item.id)
|
|
||||||
item3 = crud.item.get(db_session=db_session, id=item.id)
|
|
||||||
assert item3 is None
|
assert item3 is None
|
||||||
assert item2.id == item.id
|
assert item2.id == item.id
|
||||||
assert item2.title == title
|
assert item2.title == title
|
||||||
|
|||||||
@@ -1,94 +1,94 @@
|
|||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app import crud
|
from app import crud
|
||||||
from app.core.security import get_password_hash, verify_password
|
from app.core.security import verify_password
|
||||||
from app.db.session import db_session
|
|
||||||
from app.schemas.user import UserCreate, UserUpdate
|
from app.schemas.user import UserCreate, UserUpdate
|
||||||
from app.tests.utils.utils import random_lower_string, random_email
|
from app.tests.utils.utils import random_email, random_lower_string
|
||||||
|
|
||||||
|
|
||||||
def test_create_user():
|
def test_create_user(db: Session) -> None:
|
||||||
email = random_email()
|
email = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
user_in = UserCreate(email=email, password=password)
|
user_in = UserCreate(email=email, password=password)
|
||||||
user = crud.user.create(db_session, obj_in=user_in)
|
user = crud.user.create(db, obj_in=user_in)
|
||||||
assert user.email == email
|
assert user.email == email
|
||||||
assert hasattr(user, "hashed_password")
|
assert hasattr(user, "hashed_password")
|
||||||
|
|
||||||
|
|
||||||
def test_authenticate_user():
|
def test_authenticate_user(db: Session) -> None:
|
||||||
email = random_email()
|
email = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
user_in = UserCreate(email=email, password=password)
|
user_in = UserCreate(email=email, password=password)
|
||||||
user = crud.user.create(db_session, obj_in=user_in)
|
user = crud.user.create(db, obj_in=user_in)
|
||||||
authenticated_user = crud.user.authenticate(
|
authenticated_user = crud.user.authenticate(db, email=email, password=password)
|
||||||
db_session, email=email, password=password
|
|
||||||
)
|
|
||||||
assert authenticated_user
|
assert authenticated_user
|
||||||
assert user.email == authenticated_user.email
|
assert user.email == authenticated_user.email
|
||||||
|
|
||||||
|
|
||||||
def test_not_authenticate_user():
|
def test_not_authenticate_user(db: Session) -> None:
|
||||||
email = random_email()
|
email = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
user = crud.user.authenticate(db_session, email=email, password=password)
|
user = crud.user.authenticate(db, email=email, password=password)
|
||||||
assert user is None
|
assert user is None
|
||||||
|
|
||||||
|
|
||||||
def test_check_if_user_is_active():
|
def test_check_if_user_is_active(db: Session) -> None:
|
||||||
email = random_email()
|
email = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
user_in = UserCreate(email=email, password=password)
|
user_in = UserCreate(email=email, password=password)
|
||||||
user = crud.user.create(db_session, obj_in=user_in)
|
user = crud.user.create(db, obj_in=user_in)
|
||||||
is_active = crud.user.is_active(user)
|
is_active = crud.user.is_active(user)
|
||||||
assert is_active is True
|
assert is_active is True
|
||||||
|
|
||||||
|
|
||||||
def test_check_if_user_is_active_inactive():
|
def test_check_if_user_is_active_inactive(db: Session) -> None:
|
||||||
email = random_email()
|
email = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
user_in = UserCreate(email=email, password=password, disabled=True)
|
user_in = UserCreate(email=email, password=password, disabled=True)
|
||||||
user = crud.user.create(db_session, obj_in=user_in)
|
user = crud.user.create(db, obj_in=user_in)
|
||||||
is_active = crud.user.is_active(user)
|
is_active = crud.user.is_active(user)
|
||||||
assert is_active
|
assert is_active
|
||||||
|
|
||||||
|
|
||||||
def test_check_if_user_is_superuser():
|
def test_check_if_user_is_superuser(db: Session) -> None:
|
||||||
email = random_email()
|
email = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
user_in = UserCreate(email=email, password=password, is_superuser=True)
|
user_in = UserCreate(email=email, password=password, is_superuser=True)
|
||||||
user = crud.user.create(db_session, obj_in=user_in)
|
user = crud.user.create(db, obj_in=user_in)
|
||||||
is_superuser = crud.user.is_superuser(user)
|
is_superuser = crud.user.is_superuser(user)
|
||||||
assert is_superuser is True
|
assert is_superuser is True
|
||||||
|
|
||||||
|
|
||||||
def test_check_if_user_is_superuser_normal_user():
|
def test_check_if_user_is_superuser_normal_user(db: Session) -> None:
|
||||||
username = random_email()
|
username = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
user_in = UserCreate(email=username, password=password)
|
user_in = UserCreate(email=username, password=password)
|
||||||
user = crud.user.create(db_session, obj_in=user_in)
|
user = crud.user.create(db, obj_in=user_in)
|
||||||
is_superuser = crud.user.is_superuser(user)
|
is_superuser = crud.user.is_superuser(user)
|
||||||
assert is_superuser is False
|
assert is_superuser is False
|
||||||
|
|
||||||
|
|
||||||
def test_get_user():
|
def test_get_user(db: Session) -> None:
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
username = random_email()
|
username = random_email()
|
||||||
user_in = UserCreate(email=username, password=password, is_superuser=True)
|
user_in = UserCreate(email=username, password=password, is_superuser=True)
|
||||||
user = crud.user.create(db_session, obj_in=user_in)
|
user = crud.user.create(db, obj_in=user_in)
|
||||||
user_2 = crud.user.get(db_session, id=user.id)
|
user_2 = crud.user.get(db, id=user.id)
|
||||||
|
assert user_2
|
||||||
assert user.email == user_2.email
|
assert user.email == user_2.email
|
||||||
assert jsonable_encoder(user) == jsonable_encoder(user_2)
|
assert jsonable_encoder(user) == jsonable_encoder(user_2)
|
||||||
|
|
||||||
|
|
||||||
def test_update_user():
|
def test_update_user(db: Session) -> None:
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
email = random_email()
|
email = random_email()
|
||||||
user_in = UserCreate(email=email, password=password, is_superuser=True)
|
user_in = UserCreate(email=email, password=password, is_superuser=True)
|
||||||
user = crud.user.create(db_session, obj_in=user_in)
|
user = crud.user.create(db, obj_in=user_in)
|
||||||
new_password = random_lower_string()
|
new_password = random_lower_string()
|
||||||
user_in = UserUpdate(password=new_password, is_superuser=True)
|
user_in_update = UserUpdate(password=new_password, is_superuser=True)
|
||||||
crud.user.update(db_session, db_obj=user, obj_in=user_in)
|
crud.user.update(db, db_obj=user, obj_in=user_in_update)
|
||||||
user_2 = crud.user.get(db_session, id=user.id)
|
user_2 = crud.user.get(db, id=user.id)
|
||||||
|
assert user_2
|
||||||
assert user.email == user_2.email
|
assert user.email == user_2.email
|
||||||
assert verify_password(new_password, user_2.hashed_password)
|
assert verify_password(new_password, user_2.hashed_password)
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
from app import crud
|
from typing import Optional
|
||||||
from app.db.session import db_session
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app import crud, models
|
||||||
from app.schemas.item import ItemCreate
|
from app.schemas.item import ItemCreate
|
||||||
from app.tests.utils.user import create_random_user
|
from app.tests.utils.user import create_random_user
|
||||||
from app.tests.utils.utils import random_lower_string
|
from app.tests.utils.utils import random_lower_string
|
||||||
|
|
||||||
|
|
||||||
def create_random_item(owner_id: int = None):
|
def create_random_item(db: Session, *, owner_id: Optional[int] = None) -> models.Item:
|
||||||
if owner_id is None:
|
if owner_id is None:
|
||||||
user = create_random_user()
|
user = create_random_user(db)
|
||||||
owner_id = user.id
|
owner_id = user.id
|
||||||
title = random_lower_string()
|
title = random_lower_string()
|
||||||
description = random_lower_string()
|
description = random_lower_string()
|
||||||
item_in = ItemCreate(title=title, description=description, id=id)
|
item_in = ItemCreate(title=title, description=description, id=id)
|
||||||
return crud.item.create_with_owner(
|
return crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=owner_id)
|
||||||
db_session=db_session, obj_in=item_in, owner_id=owner_id
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,43 +1,50 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app import crud
|
from app import crud
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.db.session import db_session
|
from app.models.user import User
|
||||||
from app.schemas.user import UserCreate, UserUpdate
|
from app.schemas.user import UserCreate, UserUpdate
|
||||||
from app.tests.utils.utils import get_server_api, random_lower_string, random_email
|
from app.tests.utils.utils import get_server_api, random_email, random_lower_string
|
||||||
|
|
||||||
|
|
||||||
def user_authentication_headers(server_api, email, password):
|
def user_authentication_headers(
|
||||||
|
server_api: str, email: str, password: str
|
||||||
|
) -> Dict[str, str]:
|
||||||
data = {"username": email, "password": password}
|
data = {"username": email, "password": password}
|
||||||
|
|
||||||
r = requests.post(f"{server_api}{settings.API_V1_STR}/login/access-token", data=data)
|
r = requests.post(
|
||||||
|
f"{server_api}{settings.API_V1_STR}/login/access-token", data=data
|
||||||
|
)
|
||||||
response = r.json()
|
response = r.json()
|
||||||
auth_token = response["access_token"]
|
auth_token = response["access_token"]
|
||||||
headers = {"Authorization": f"Bearer {auth_token}"}
|
headers = {"Authorization": f"Bearer {auth_token}"}
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
def create_random_user():
|
def create_random_user(db: Session) -> User:
|
||||||
email = random_email()
|
email = random_email()
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
user_in = UserCreate(username=email, email=email, password=password)
|
user_in = UserCreate(username=email, email=email, password=password)
|
||||||
user = crud.user.create(db_session=db_session, obj_in=user_in)
|
user = crud.user.create(db=db, obj_in=user_in)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def authentication_token_from_email(email):
|
def authentication_token_from_email(*, email: str, db: Session) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Return a valid token for the user with given email.
|
Return a valid token for the user with given email.
|
||||||
|
|
||||||
If the user doesn't exist it is created first.
|
If the user doesn't exist it is created first.
|
||||||
"""
|
"""
|
||||||
password = random_lower_string()
|
password = random_lower_string()
|
||||||
user = crud.user.get_by_email(db_session, email=email)
|
user = crud.user.get_by_email(db, email=email)
|
||||||
if not user:
|
if not user:
|
||||||
user_in = UserCreate(username=email, email=email, password=password)
|
user_in_create = UserCreate(username=email, email=email, password=password)
|
||||||
user = crud.user.create(db_session=db_session, obj_in=user_in)
|
user = crud.user.create(db, obj_in=user_in_create)
|
||||||
else:
|
else:
|
||||||
user_in = UserUpdate(password=password)
|
user_in_update = UserUpdate(password=password)
|
||||||
user = crud.user.update(db_session, db_obj=user, obj_in=user_in)
|
user = crud.user.update(db, db_obj=user, obj_in=user_in_update)
|
||||||
|
|
||||||
return user_authentication_headers(get_server_api(), email, password)
|
return user_authentication_headers(get_server_api(), email, password)
|
||||||
|
|||||||
@@ -1,25 +1,26 @@
|
|||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
def random_lower_string():
|
def random_lower_string() -> str:
|
||||||
return "".join(random.choices(string.ascii_lowercase, k=32))
|
return "".join(random.choices(string.ascii_lowercase, k=32))
|
||||||
|
|
||||||
|
|
||||||
def random_email():
|
def random_email() -> str:
|
||||||
return f"{random_lower_string()}@{random_lower_string()}.com"
|
return f"{random_lower_string()}@{random_lower_string()}.com"
|
||||||
|
|
||||||
|
|
||||||
def get_server_api():
|
def get_server_api() -> str:
|
||||||
server_name = f"http://{settings.SERVER_NAME}"
|
server_name = f"http://{settings.SERVER_NAME}"
|
||||||
return server_name
|
return server_name
|
||||||
|
|
||||||
|
|
||||||
def get_superuser_token_headers():
|
def get_superuser_token_headers() -> Dict[str, str]:
|
||||||
server_api = get_server_api()
|
server_api = get_server_api()
|
||||||
login_data = {
|
login_data = {
|
||||||
"username": settings.FIRST_SUPERUSER,
|
"username": settings.FIRST_SUPERUSER,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import logging
|
|||||||
|
|
||||||
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
|
||||||
|
|
||||||
from app.db.session import db_session
|
from app.db.session import SessionLocal
|
||||||
from app.tests.api.api_v1.test_login import test_get_access_token
|
from app.tests.api.api_v1.test_login import test_get_access_token
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -18,10 +18,11 @@ wait_seconds = 1
|
|||||||
before=before_log(logger, logging.INFO),
|
before=before_log(logger, logging.INFO),
|
||||||
after=after_log(logger, logging.WARN),
|
after=after_log(logger, logging.WARN),
|
||||||
)
|
)
|
||||||
def init():
|
def init() -> None:
|
||||||
try:
|
try:
|
||||||
# Try to create session to check if DB is awake
|
# Try to create session to check if DB is awake
|
||||||
db_session.execute("SELECT 1")
|
db = SessionLocal()
|
||||||
|
db.execute("SELECT 1")
|
||||||
# Wait for API to be awake, run one simple tests to authenticate
|
# Wait for API to be awake, run one simple tests to authenticate
|
||||||
test_get_access_token()
|
test_get_access_token()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -29,7 +30,7 @@ def init():
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
logger.info("Initializing service")
|
logger.info("Initializing service")
|
||||||
init()
|
init()
|
||||||
logger.info("Service finished initializing")
|
logger.info("Service finished initializing")
|
||||||
|
|||||||
@@ -1,19 +1,21 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import emails
|
import emails
|
||||||
import jwt
|
|
||||||
from emails.template import JinjaTemplate
|
from emails.template import JinjaTemplate
|
||||||
from jwt.exceptions import InvalidTokenError
|
from jose import jwt
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
password_reset_jwt_subject = "preset"
|
|
||||||
|
|
||||||
|
def send_email(
|
||||||
def send_email(email_to: str, subject_template="", html_template="", environment={}):
|
email_to: str,
|
||||||
|
subject_template: str = "",
|
||||||
|
html_template: str = "",
|
||||||
|
environment: Dict[str, Any] = {},
|
||||||
|
) -> None:
|
||||||
assert settings.EMAILS_ENABLED, "no provided configuration for email variables"
|
assert settings.EMAILS_ENABLED, "no provided configuration for email variables"
|
||||||
message = emails.Message(
|
message = emails.Message(
|
||||||
subject=JinjaTemplate(subject_template),
|
subject=JinjaTemplate(subject_template),
|
||||||
@@ -31,7 +33,7 @@ def send_email(email_to: str, subject_template="", html_template="", environment
|
|||||||
logging.info(f"send email result: {response}")
|
logging.info(f"send email result: {response}")
|
||||||
|
|
||||||
|
|
||||||
def send_test_email(email_to: str):
|
def send_test_email(email_to: str) -> None:
|
||||||
project_name = settings.PROJECT_NAME
|
project_name = settings.PROJECT_NAME
|
||||||
subject = f"{project_name} - Test email"
|
subject = f"{project_name} - Test email"
|
||||||
with open(Path(settings.EMAIL_TEMPLATES_DIR) / "test_email.html") as f:
|
with open(Path(settings.EMAIL_TEMPLATES_DIR) / "test_email.html") as f:
|
||||||
@@ -44,17 +46,13 @@ def send_test_email(email_to: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def send_reset_password_email(email_to: str, email: str, token: str):
|
def send_reset_password_email(email_to: str, email: str, token: str) -> None:
|
||||||
project_name = settings.PROJECT_NAME
|
project_name = settings.PROJECT_NAME
|
||||||
subject = f"{project_name} - Password recovery for user {email}"
|
subject = f"{project_name} - Password recovery for user {email}"
|
||||||
with open(Path(settings.EMAIL_TEMPLATES_DIR) / "reset_password.html") as f:
|
with open(Path(settings.EMAIL_TEMPLATES_DIR) / "reset_password.html") as f:
|
||||||
template_str = f.read()
|
template_str = f.read()
|
||||||
if hasattr(token, "decode"):
|
|
||||||
use_token = token.decode()
|
|
||||||
else:
|
|
||||||
use_token = token
|
|
||||||
server_host = settings.SERVER_HOST
|
server_host = settings.SERVER_HOST
|
||||||
link = f"{server_host}/reset-password?token={use_token}"
|
link = f"{server_host}/reset-password?token={token}"
|
||||||
send_email(
|
send_email(
|
||||||
email_to=email_to,
|
email_to=email_to,
|
||||||
subject_template=subject,
|
subject_template=subject,
|
||||||
@@ -69,7 +67,7 @@ def send_reset_password_email(email_to: str, email: str, token: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def send_new_account_email(email_to: str, username: str, password: str):
|
def send_new_account_email(email_to: str, username: str, password: str) -> None:
|
||||||
project_name = settings.PROJECT_NAME
|
project_name = settings.PROJECT_NAME
|
||||||
subject = f"{project_name} - New account for user {username}"
|
subject = f"{project_name} - New account for user {username}"
|
||||||
with open(Path(settings.EMAIL_TEMPLATES_DIR) / "new_account.html") as f:
|
with open(Path(settings.EMAIL_TEMPLATES_DIR) / "new_account.html") as f:
|
||||||
@@ -89,23 +87,20 @@ def send_new_account_email(email_to: str, username: str, password: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_password_reset_token(email):
|
def generate_password_reset_token(email: str) -> str:
|
||||||
delta = timedelta(hours=settings.EMAIL_RESET_TOKEN_EXPIRE_HOURS)
|
delta = timedelta(hours=settings.EMAIL_RESET_TOKEN_EXPIRE_HOURS)
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
expires = now + delta
|
expires = now + delta
|
||||||
exp = expires.timestamp()
|
exp = expires.timestamp()
|
||||||
encoded_jwt = jwt.encode(
|
encoded_jwt = jwt.encode(
|
||||||
{"exp": exp, "nbf": now, "sub": password_reset_jwt_subject, "email": email},
|
{"exp": exp, "nbf": now, "sub": email}, settings.SECRET_KEY, algorithm="HS256",
|
||||||
settings.SECRET_KEY,
|
|
||||||
algorithm="HS256",
|
|
||||||
)
|
)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
|
|
||||||
def verify_password_reset_token(token) -> Optional[str]:
|
def verify_password_reset_token(token: str) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
decoded_token = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
decoded_token = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||||
assert decoded_token["sub"] == password_reset_jwt_subject
|
|
||||||
return decoded_token["email"]
|
return decoded_token["email"]
|
||||||
except InvalidTokenError:
|
except jwt.JWTError:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from raven import Client
|
from raven import Client
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.celery_app import celery_app
|
from app.core.celery_app import celery_app
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
client_sentry = Client(settings.SENTRY_DSN)
|
client_sentry = Client(settings.SENTRY_DSN)
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(acks_late=True)
|
@celery_app.task(acks_late=True)
|
||||||
def test_celery(word: str):
|
def test_celery(word: str) -> str:
|
||||||
return f"test task return {word}"
|
return f"test task return {word}"
|
||||||
|
|||||||
4
{{cookiecutter.project_slug}}/backend/app/mypy.ini
Normal file
4
{{cookiecutter.project_slug}}/backend/app/mypy.ini
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
[mypy]
|
||||||
|
plugins = pydantic.mypy, sqlmypy
|
||||||
|
ignore_missing_imports = True
|
||||||
|
disallow_untyped_defs = True
|
||||||
@@ -8,7 +8,6 @@ authors = ["Admin <admin@example.com>"]
|
|||||||
python = "^3.7"
|
python = "^3.7"
|
||||||
uvicorn = "^0.11.3"
|
uvicorn = "^0.11.3"
|
||||||
fastapi = "^0.54.1"
|
fastapi = "^0.54.1"
|
||||||
pyjwt = "^1.7.1"
|
|
||||||
python-multipart = "^0.0.5"
|
python-multipart = "^0.0.5"
|
||||||
email-validator = "^1.0.5"
|
email-validator = "^1.0.5"
|
||||||
requests = "^2.23.0"
|
requests = "^2.23.0"
|
||||||
@@ -24,6 +23,7 @@ psycopg2-binary = "^2.8.5"
|
|||||||
alembic = "^1.4.2"
|
alembic = "^1.4.2"
|
||||||
sqlalchemy = "^1.3.16"
|
sqlalchemy = "^1.3.16"
|
||||||
pytest = "^5.4.1"
|
pytest = "^5.4.1"
|
||||||
|
python-jose = {extras = ["cryptography"], version = "^3.1.0"}
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
mypy = "^0.770"
|
mypy = "^0.770"
|
||||||
@@ -32,9 +32,15 @@ isort = "^4.3.21"
|
|||||||
autoflake = "^1.3.1"
|
autoflake = "^1.3.1"
|
||||||
flake8 = "^3.7.9"
|
flake8 = "^3.7.9"
|
||||||
pytest = "^5.4.1"
|
pytest = "^5.4.1"
|
||||||
jupyter = "^1.0.0"
|
sqlalchemy-stubs = "^0.3"
|
||||||
vulture = "^1.4"
|
pytest-cov = "^2.8.1"
|
||||||
|
|
||||||
|
[tool.isort]
|
||||||
|
multi_line_output = 3
|
||||||
|
include_trailing_comma = true
|
||||||
|
force_grid_wrap = 0
|
||||||
|
line_length = 88
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry>=0.12"]
|
requires = ["poetry>=0.12"]
|
||||||
build-backend = "poetry.masonry.api"
|
build-backend = "poetry.masonry.api"
|
||||||
|
|
||||||
|
|||||||
6
{{cookiecutter.project_slug}}/backend/app/scripts/format-imports.sh
Executable file
6
{{cookiecutter.project_slug}}/backend/app/scripts/format-imports.sh
Executable file
@@ -0,0 +1,6 @@
|
|||||||
|
#!/bin/sh -e
|
||||||
|
set -x
|
||||||
|
|
||||||
|
# Sort imports one per line, so autoflake can remove unused imports
|
||||||
|
isort --recursive --force-single-line-imports --apply app
|
||||||
|
sh ./scripts/format.sh
|
||||||
6
{{cookiecutter.project_slug}}/backend/app/scripts/format.sh
Executable file
6
{{cookiecutter.project_slug}}/backend/app/scripts/format.sh
Executable file
@@ -0,0 +1,6 @@
|
|||||||
|
#!/bin/sh -e
|
||||||
|
set -x
|
||||||
|
|
||||||
|
autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place app --exclude=__init__.py
|
||||||
|
black app
|
||||||
|
isort --recursive --apply app
|
||||||
@@ -2,7 +2,8 @@
|
|||||||
|
|
||||||
set -x
|
set -x
|
||||||
|
|
||||||
autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place app --exclude=__init__.py
|
mypy app
|
||||||
isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --apply app
|
black app --check
|
||||||
black app
|
isort --recursive --check-only app
|
||||||
vulture app --min-confidence 70
|
vulture app --min-confidence 70
|
||||||
|
flake8
|
||||||
|
|||||||
6
{{cookiecutter.project_slug}}/backend/app/scripts/test-cov-html.sh
Executable file
6
{{cookiecutter.project_slug}}/backend/app/scripts/test-cov-html.sh
Executable file
@@ -0,0 +1,6 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
set -x
|
||||||
|
|
||||||
|
bash scripts/test.sh --cov-report=html "${@}"
|
||||||
6
{{cookiecutter.project_slug}}/backend/app/scripts/test.sh
Executable file
6
{{cookiecutter.project_slug}}/backend/app/scripts/test.sh
Executable file
@@ -0,0 +1,6 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
set -x
|
||||||
|
|
||||||
|
pytest --cov=app --cov-report=term-missing app/tests "${@}"
|
||||||
@@ -3,4 +3,4 @@ set -e
|
|||||||
|
|
||||||
python /app/app/tests_pre_start.py
|
python /app/app/tests_pre_start.py
|
||||||
|
|
||||||
pytest "$@" /app/app/tests/
|
bash ./scripts/test.sh "$@"
|
||||||
|
|||||||
@@ -10,17 +10,16 @@ RUN curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-
|
|||||||
|
|
||||||
# Copy poetry.lock* in case it doesn't exist in the repo
|
# Copy poetry.lock* in case it doesn't exist in the repo
|
||||||
COPY ./app/pyproject.toml ./app/poetry.lock* /app/
|
COPY ./app/pyproject.toml ./app/poetry.lock* /app/
|
||||||
RUN poetry install --no-dev --no-root
|
|
||||||
|
# Allow installing dev dependencies to run tests
|
||||||
|
ARG INSTALL_DEV=false
|
||||||
|
RUN bash -c "if [ $INSTALL_DEV == 'true' ] ; then poetry install --no-root ; else poetry install --no-root --no-dev ; fi"
|
||||||
|
|
||||||
# For development, Jupyter remote kernel, Hydrogen
|
# For development, Jupyter remote kernel, Hydrogen
|
||||||
# Using inside the container:
|
# Using inside the container:
|
||||||
# jupyter lab --ip=0.0.0.0 --allow-root --NotebookApp.custom_display_url=http://127.0.0.1:8888
|
# jupyter lab --ip=0.0.0.0 --allow-root --NotebookApp.custom_display_url=http://127.0.0.1:8888
|
||||||
ARG env=prod
|
ARG INSTALL_JUPYTER=false
|
||||||
RUN bash -c "if [ $env == 'dev' ] ; then pip install jupyterlab ; fi"
|
RUN bash -c "if [ $INSTALL_JUPYTER == 'true' ] ; then pip install jupyterlab ; fi"
|
||||||
EXPOSE 8888
|
|
||||||
|
|
||||||
COPY ./app /app
|
COPY ./app /app
|
||||||
|
|
||||||
ENV PYTHONPATH=/app
|
ENV PYTHONPATH=/app
|
||||||
|
|
||||||
EXPOSE 80
|
|
||||||
|
|||||||
@@ -10,14 +10,16 @@ RUN curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-
|
|||||||
|
|
||||||
# Copy poetry.lock* in case it doesn't exist in the repo
|
# Copy poetry.lock* in case it doesn't exist in the repo
|
||||||
COPY ./app/pyproject.toml ./app/poetry.lock* /app/
|
COPY ./app/pyproject.toml ./app/poetry.lock* /app/
|
||||||
RUN poetry install --no-dev --no-root
|
|
||||||
|
# Allow installing dev dependencies to run tests
|
||||||
|
ARG INSTALL_DEV=false
|
||||||
|
RUN bash -c "if [ $INSTALL_DEV == 'true' ] ; then poetry install --no-root ; else poetry install --no-root --no-dev ; fi"
|
||||||
|
|
||||||
# For development, Jupyter remote kernel, Hydrogen
|
# For development, Jupyter remote kernel, Hydrogen
|
||||||
# Using inside the container:
|
# Using inside the container:
|
||||||
# jupyter lab --ip=0.0.0.0 --allow-root --NotebookApp.custom_display_url=http://127.0.0.1:8888
|
# jupyter lab --ip=0.0.0.0 --allow-root --NotebookApp.custom_display_url=http://127.0.0.1:8888
|
||||||
ARG env=prod
|
ARG INSTALL_JUPYTER=false
|
||||||
RUN bash -c "if [ $env == 'dev' ] ; then pip install jupyterlab ; fi"
|
RUN bash -c "if [ $INSTALL_JUPYTER == 'true' ] ; then pip install jupyterlab ; fi"
|
||||||
EXPOSE 8888
|
|
||||||
|
|
||||||
ENV C_FORCE_ROOT=1
|
ENV C_FORCE_ROOT=1
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,8 @@ services:
|
|||||||
context: ./backend
|
context: ./backend
|
||||||
dockerfile: backend.dockerfile
|
dockerfile: backend.dockerfile
|
||||||
args:
|
args:
|
||||||
env: dev
|
INSTALL_DEV: ${INSTALL_DEV-true}
|
||||||
|
INSTALL_JUPYTER: ${INSTALL_JUPYTER-true}
|
||||||
# command: bash -c "while true; do sleep 1; done" # Infinite loop to keep container live doing nothing
|
# command: bash -c "while true; do sleep 1; done" # Infinite loop to keep container live doing nothing
|
||||||
command: /start-reload.sh
|
command: /start-reload.sh
|
||||||
labels:
|
labels:
|
||||||
@@ -57,7 +58,8 @@ services:
|
|||||||
context: ./backend
|
context: ./backend
|
||||||
dockerfile: celeryworker.dockerfile
|
dockerfile: celeryworker.dockerfile
|
||||||
args:
|
args:
|
||||||
env: dev
|
INSTALL_DEV: ${INSTALL_DEV-true}
|
||||||
|
INSTALL_JUPYTER: ${INSTALL_JUPYTER-true}
|
||||||
|
|
||||||
frontend:
|
frontend:
|
||||||
build:
|
build:
|
||||||
|
|||||||
@@ -115,6 +115,8 @@ services:
|
|||||||
build:
|
build:
|
||||||
context: ./backend
|
context: ./backend
|
||||||
dockerfile: backend.dockerfile
|
dockerfile: backend.dockerfile
|
||||||
|
args:
|
||||||
|
INSTALL_DEV: ${INSTALL_DEV-false}
|
||||||
deploy:
|
deploy:
|
||||||
labels:
|
labels:
|
||||||
- traefik.frontend.rule=PathPrefix:/api,/docs,/redoc
|
- traefik.frontend.rule=PathPrefix:/api,/docs,/redoc
|
||||||
@@ -137,6 +139,8 @@ services:
|
|||||||
build:
|
build:
|
||||||
context: ./backend
|
context: ./backend
|
||||||
dockerfile: celeryworker.dockerfile
|
dockerfile: celeryworker.dockerfile
|
||||||
|
args:
|
||||||
|
INSTALL_DEV: ${INSTALL_DEV-false}
|
||||||
|
|
||||||
frontend:
|
frontend:
|
||||||
image: '${DOCKER_IMAGE_FRONTEND}:${TAG-latest}'
|
image: '${DOCKER_IMAGE_FRONTEND}:${TAG-latest}'
|
||||||
|
|||||||
@@ -5,6 +5,6 @@ set -e
|
|||||||
|
|
||||||
TAG=${TAG} \
|
TAG=${TAG} \
|
||||||
FRONTEND_ENV=${FRONTEND_ENV-production} \
|
FRONTEND_ENV=${FRONTEND_ENV-production} \
|
||||||
. ./scripts/build.sh
|
sh ./scripts/build.sh
|
||||||
|
|
||||||
docker-compose -f docker-compose.yml push
|
docker-compose -f docker-compose.yml push
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ set -e
|
|||||||
DOMAIN=backend \
|
DOMAIN=backend \
|
||||||
SMTP_HOST="" \
|
SMTP_HOST="" \
|
||||||
TRAEFIK_PUBLIC_NETWORK_IS_EXTERNAL=false \
|
TRAEFIK_PUBLIC_NETWORK_IS_EXTERNAL=false \
|
||||||
|
INSTALL_DEV=true \
|
||||||
docker-compose \
|
docker-compose \
|
||||||
-f docker-compose.yml \
|
-f docker-compose.yml \
|
||||||
config > docker-stack.yml
|
config > docker-stack.yml
|
||||||
|
|||||||
Reference in New Issue
Block a user