Mirror of https://gitee.com/dify_ai/dify.git, synced 2025-12-06 19:42:42 +08:00

Compare commits: refactor/i... → build/add-... (84 commits)
| SHA1 |
|---|
| c5b07b999c |
| c9f785e00f |
| 0e8ab0588f |
| 0ebe198ff1 |
| 438ad8148b |
| a60133bfb3 |
| 98b3e37144 |
| 6e23903c63 |
| 574c4a264f |
| dd5ffaf058 |
| 0b16270b88 |
| f562a88249 |
| 59f8d116af |
| a5558f8fcc |
| 823ae03a08 |
| f8c958a409 |
| 25785d8c3f |
| 35d3da9697 |
| d3e9930235 |
| 1ccca7cc68 |
| 12a9e2972a |
| 444c1f170a |
| 3cb2fb8250 |
| 1e8457441d |
| 5a9448245b |
| eafe5a9d8f |
| d45d90e8ae |
| 42a9374e71 |
| 82a775eca3 |
| 1dae1a71fc |
| ac0fed6402 |
| fb656d480e |
| 2b7341af57 |
| ce1f9d935d |
| bdadca1a65 |
| d7b4d0756e |
| 1279e27825 |
| d92e3bd620 |
| 7f583ec1ac |
| 7962101e5e |
| ae254f0a10 |
| 68e0b0ac84 |
| 5f21d13572 |
| 233bffdb7d |
| bf9349c4dc |
| 4847548779 |
| cb245b5435 |
| 249b897872 |
| 08c731fd84 |
| 302f4407f6 |
| de5dfd99f6 |
| acb22f0fde |
| d1505b15c4 |
| cca2e7876d |
| 2c4d8dbe9b |
| 9305ad2102 |
| 7a98dab6a4 |
| 971defbbbd |
| 6b0de08157 |
| 87c1de66f2 |
| 2aa171c348 |
| 6452342222 |
| da204c131d |
| 9369cc44e6 |
| 38bca6731c |
| 2adab7f71a |
| be96f6e62d |
| 8b5ea39916 |
| 1024fc623e |
| 8ab05d4c36 |
| 0c9e79cd67 |
| 2ed6bb86c1 |
| 61da0f08dd |
| 1432c268a8 |
| ec6a03afdd |
| bf371a6e5d |
| b28cf68097 |
| a0af7a51ed |
| dfa3ef0564 |
| 0066531266 |
| 53a7cb0e9d |
| 86739bea8b |
| ab127ba92e |
| 6a2a9460e9 |
.github/dependabot.yml (vendored, new file, 10 lines)

```diff
@@ -0,0 +1,10 @@
+version: 2
+updates:
+  - package-ecosystem: npm
+    directory: /web
+    schedule:
+      interval: weekly
+    assignees:
+      - akarachen
+    labels:
+      - dependencies
```
.github/workflows/api-tests.yml (vendored, 33 changes)

```diff
@@ -7,6 +7,7 @@ on:
     paths:
       - api/**
       - docker/**
+      - .github/workflows/api-tests.yml

 concurrency:
   group: api-tests-${{ github.head_ref || github.run_id }}
@@ -27,16 +28,15 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v4

+      - name: Install Poetry
+        uses: abatilo/actions-poetry@v3
+
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
-          cache-dependency-path: |
-            api/pyproject.toml
-            api/poetry.lock
-
-      - name: Install Poetry
-        uses: abatilo/actions-poetry@v3
+          cache: poetry
+          cache-dependency-path: api/poetry.lock

       - name: Check Poetry lockfile
         run: |
@@ -67,7 +67,7 @@ jobs:
         run: sh .github/workflows/expose_service_ports.sh

       - name: Set up Sandbox
-        uses: hoverkraft-tech/compose-action@v2.0.0
+        uses: hoverkraft-tech/compose-action@v2.0.2
         with:
           compose-file: |
             docker/docker-compose.middleware.yaml
@@ -77,22 +77,3 @@ jobs:

       - name: Run Workflow
         run: poetry run -C api bash dev/pytest/pytest_workflow.sh
-
-      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
-        uses: hoverkraft-tech/compose-action@v2.0.0
-        with:
-          compose-file: |
-            docker/docker-compose.yaml
-          services: |
-            weaviate
-            qdrant
-            couchbase-server
-            etcd
-            minio
-            milvus-standalone
-            pgvecto-rs
-            pgvector
-            chroma
-            elasticsearch
-      - name: Test Vector Stores
-        run: poetry run -C api bash dev/pytest/pytest_vdb.sh
```
.github/workflows/db-migration-test.yml (vendored, 3 changes)

```diff
@@ -6,6 +6,7 @@ on:
       - main
     paths:
       - api/migrations/**
+      - .github/workflows/db-migration-test.yml

 concurrency:
   group: db-migration-test-${{ github.ref }}
@@ -43,7 +44,7 @@
         cp middleware.env.example middleware.env

       - name: Set up Middlewares
-        uses: hoverkraft-tech/compose-action@v2.0.0
+        uses: hoverkraft-tech/compose-action@v2.0.2
         with:
           compose-file: |
             docker/docker-compose.middleware.yaml
```
.github/workflows/style.yml (vendored, 8 changes)

```diff
@@ -24,16 +24,16 @@ jobs:
         with:
           files: api/**

+      - name: Install Poetry
+        if: steps.changed-files.outputs.any_changed == 'true'
+        uses: abatilo/actions-poetry@v3
+
       - name: Set up Python
         uses: actions/setup-python@v5
         if: steps.changed-files.outputs.any_changed == 'true'
         with:
           python-version: '3.10'

-      - name: Install Poetry
-        if: steps.changed-files.outputs.any_changed == 'true'
-        uses: abatilo/actions-poetry@v3
-
       - name: Python dependencies
         if: steps.changed-files.outputs.any_changed == 'true'
         run: poetry install -C api --only lint
```
.github/workflows/vdb-tests.yml (vendored, new file, 75 lines)

```diff
@@ -0,0 +1,75 @@
+name: Run VDB Tests
+
+on:
+  pull_request:
+    branches:
+      - main
+    paths:
+      - api/core/rag/datasource/**
+      - docker/**
+      - .github/workflows/vdb-tests.yml
+
+concurrency:
+  group: vdb-tests-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  test:
+    name: VDB Tests
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version:
+          - "3.10"
+          - "3.11"
+          - "3.12"
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Install Poetry
+        uses: abatilo/actions-poetry@v3
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: poetry
+          cache-dependency-path: api/poetry.lock
+
+      - name: Check Poetry lockfile
+        run: |
+          poetry check -C api --lock
+          poetry show -C api
+
+      - name: Install dependencies
+        run: poetry install -C api --with dev
+
+      - name: Set up dotenvs
+        run: |
+          cp docker/.env.example docker/.env
+          cp docker/middleware.env.example docker/middleware.env
+
+      - name: Expose Service Ports
+        run: sh .github/workflows/expose_service_ports.sh
+
+      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
+        uses: hoverkraft-tech/compose-action@v2.0.2
+        with:
+          compose-file: |
+            docker/docker-compose.yaml
+          services: |
+            weaviate
+            qdrant
+            couchbase-server
+            etcd
+            minio
+            milvus-standalone
+            pgvecto-rs
+            pgvector
+            chroma
+            elasticsearch
+
+      - name: Test Vector Stores
+        run: poetry run -C api bash dev/pytest/pytest_vdb.sh
```
.gitignore (vendored, 1 change)

```diff
@@ -175,6 +175,7 @@ docker/volumes/pgvector/data/*
 docker/volumes/pgvecto_rs/data/*
 docker/volumes/couchbase/*
 docker/volumes/oceanbase/*
+!docker/volumes/oceanbase/init.d

 docker/nginx/conf.d/default.conf
 docker/nginx/ssl/*
```
```diff
@@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed on your system:

 Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.

-Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot.
+Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) for a list of common issues and steps to troubleshoot.

 ### 5. Visit dify in your browser
```
```diff
@@ -79,7 +79,7 @@ Dify requires the following dependencies to build; make sure they are installed on your system:

 Dify consists of a backend and a frontend. Go to the backend directory with `cd api/`, then follow the instructions in the [Backend README](api/README.md) to install it. In another terminal, go to the frontend directory with `cd web/`, then follow the instructions in the [Frontend README](web/README.md) to install it.

-Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and troubleshooting steps.
+Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) for a list of common issues and troubleshooting steps.

 ### 5. Visit Dify in your browser
```
README.md (73 changes)

````diff
@@ -46,45 +46,18 @@
 </p>

-## Table of Content
-0. [Quick-Start🚀](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)
-1. [Intro📖](https://github.com/langgenius/dify?tab=readme-ov-file#intro)
-2. [How to use🔧](https://github.com/langgenius/dify?tab=readme-ov-file#using-dify)
-3. [Stay Ahead🏃](https://github.com/langgenius/dify?tab=readme-ov-file#staying-ahead)
-4. [Next Steps🏹](https://github.com/langgenius/dify?tab=readme-ov-file#next-steps)
-5. [Contributing💪](https://github.com/langgenius/dify?tab=readme-ov-file#contributing)
-6. [Community and Contact🏠](https://github.com/langgenius/dify?tab=readme-ov-file#community--contact)
-7. [Star-History📈](https://github.com/langgenius/dify?tab=readme-ov-file#star-history)
-8. [Security🔒](https://github.com/langgenius/dify?tab=readme-ov-file#security-disclosure)
-9. [License🤝](https://github.com/langgenius/dify?tab=readme-ov-file#license)
-
-> Make sure you read through this README before you start utilizing Dify😊
-
+Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production.

 ## Quick start
-The quickest way to deploy Dify locally is to run our [docker-compose.yml](https://github.com/langgenius/dify/blob/main/docker/docker-compose.yaml). Follow the instructions to start in 5 minutes.
-
 > Before installing Dify, make sure your machine meets the following minimum system requirements:
 >
 >- CPU >= 2 Core
 >- RAM >= 4 GiB
->- Docker and Docker Compose Installed

 </br>

-Run the following command in your terminal to clone the whole repo.
-```bash
-git clone https://github.com/langgenius/dify.git
-```
-After cloning,run the following command one by one.
+The easiest way to start the Dify server is through [docker compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:

 ```bash
 cd dify
 cd docker
@@ -92,13 +65,14 @@ cp .env.example .env
 docker compose up -d
 ```

-After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. You will be asked to setup an admin account.
-For more info of quick setup, check [here](https://docs.dify.ai/getting-started/install-self-hosted/docker-compose)
+After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.

-## Intro
-Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:
+</br> </br>

+#### Seeking help
+Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-hosted/faqs) if you encounter problems setting up Dify. Reach out to [the community and us](#community--contact) if you are still having issues.

+> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)

 ## Key features
 **1. Workflow**:
 Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond.
@@ -149,20 +123,8 @@ Star Dify on GitHub and be instantly notified of new releases.

-## Next steps
-
-Go to [quick-start](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start) to setup your Dify or setup by source code.
-
-#### If you......
-If you forget your admin account, you can refer to this [guide](https://docs.dify.ai/getting-started/install-self-hosted/faqs#id-4.-how-to-reset-the-password-of-the-admin-account) to reset the password.
-
-> Use docker compose up without "-d" to enable logs printing out in your terminal. This might be useful if you have encountered unknow problems when using Dify.
-
-If you encountered system error and would like to acquire help in Github issues, make sure you always paste logs of the error in the request to accerate the conversation. Go to [Community & contact](https://github.com/langgenius/dify?tab=readme-ov-file#community--contact) for more information.
-
-> Please read the [Dify Documentation](https://docs.dify.ai/) for detailed how-to-use guidance. Most of the potential problems are explained in the doc.
-
-> If you'd like to contribute to Dify or make additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)
 ## Advanced Setup

 If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
@@ -190,19 +152,18 @@ At the same time, please consider supporting Dify by sharing it on social media

 > We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).

-**Contributors**
-
-<a href="https://github.com/langgenius/dify/graphs/contributors">
-  <img src="https://contrib.rocks/image?repo=langgenius/dify" />
-</a>
-
 ## Community & contact

 * [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions.
 * [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
 * [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
 * [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
-* Make sure a log, if possible, is attached to an error reported to maximize solution efficiency.
+
+**Contributors**
+
+<a href="https://github.com/langgenius/dify/graphs/contributors">
+  <img src="https://contrib.rocks/image?repo=langgenius/dify" />
+</a>

 ## Star history
````
```diff
@@ -120,7 +120,8 @@ SUPABASE_URL=your-server-url
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*

-# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash
+# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
 VECTOR_STORE=weaviate

 # Weaviate configuration
@@ -263,14 +264,20 @@ VIKINGDB_SCHEMA=http
 VIKINGDB_CONNECTION_TIMEOUT=30
 VIKINGDB_SOCKET_TIMEOUT=30

+# Lindorm configuration
+LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
+LINDORM_USERNAME=admin
+LINDORM_PASSWORD=admin
+
 # OceanBase Vector configuration
 OCEANBASE_VECTOR_HOST=127.0.0.1
 OCEANBASE_VECTOR_PORT=2881
 OCEANBASE_VECTOR_USER=root@test
-OCEANBASE_VECTOR_PASSWORD=
+OCEANBASE_VECTOR_PASSWORD=difyai123456
 OCEANBASE_VECTOR_DATABASE=test
 OCEANBASE_MEMORY_LIMIT=6G


 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5
@@ -313,13 +320,21 @@ ETL_TYPE=dify
 UNSTRUCTURED_API_URL=
 UNSTRUCTURED_API_KEY=

 #ssrf
 SSRF_PROXY_HTTP_URL=
 SSRF_PROXY_HTTPS_URL=
+SSRF_DEFAULT_MAX_RETRIES=3
+SSRF_DEFAULT_TIME_OUT=5
+SSRF_DEFAULT_CONNECT_TIME_OUT=5
+SSRF_DEFAULT_READ_TIME_OUT=5
+SSRF_DEFAULT_WRITE_TIME_OUT=5

 BATCH_UPLOAD_LIMIT=10
 KEYWORD_DATA_SOURCE_TYPE=database

+# Workflow file upload limit
+WORKFLOW_FILE_UPLOAD_LIMIT=10
+
 # CODE EXECUTION CONFIGURATION
 CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
 CODE_EXECUTION_API_KEY=dify-sandbox
```
```diff
@@ -55,12 +55,7 @@ RUN apt-get update \
     && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
     && apt-get update \
     # For Security
-    && apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \
-    && if [ "$(dpkg --print-architecture)" = "amd64" ]; then \
-        apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1+b1; \
-    else \
-        apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1; \
-    fi \
+    && apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
     # install a chinese font to support the use of tools like matplotlib
     && apt-get install -y fonts-noto-cjk \
     && apt-get autoremove -y \
```
````diff
@@ -76,13 +76,13 @@
 1. Install dependencies for both the backend and the test environment

    ```bash
-   poetry install --with dev
+   poetry install -C api --with dev
    ```

 2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`

    ```bash
    cd ../
    poetry run -C api bash dev/pytest/pytest_all_tests.sh
    ```
````
```diff
@@ -2,7 +2,7 @@ import os

 from configs import dify_config

-if os.environ.get("DEBUG", "false").lower() != "true":
+if not dify_config.DEBUG:
     from gevent import monkey

     monkey.patch_all()
```
```diff
@@ -1,6 +1,8 @@
 import os

-if os.environ.get("DEBUG", "false").lower() != "true":
+from configs import dify_config
+
+if not dify_config.DEBUG:
     from gevent import monkey

     monkey.patch_all()
```
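The two hunks above replace ad-hoc `os.environ` string comparison with the typed `dify_config.DEBUG` flag. A minimal sketch of why the typed flag is more forgiving, assuming a pydantic-settings model like Dify's (`SketchConfig` here is illustrative, not a real Dify class):

```python
import os

from pydantic_settings import BaseSettings


class SketchConfig(BaseSettings):
    DEBUG: bool = False


# pydantic-settings coerces common truthy strings to bool, whereas the old
# check only matched the exact lowercase string "true".
os.environ["DEBUG"] = "1"
assert SketchConfig().DEBUG is True
```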
```diff
@@ -109,7 +109,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
     )

     CODE_MAX_PRECISION: PositiveInt = Field(
-        description="mMaximum number of decimal places for floating-point numbers in code execution",
+        description="Maximum number of decimal places for floating-point numbers in code execution",
         default=20,
     )
```
```diff
@@ -216,6 +216,11 @@ class FileUploadConfig(BaseSettings):
         default=20,
     )

+    WORKFLOW_FILE_UPLOAD_LIMIT: PositiveInt = Field(
+        description="Maximum number of files allowed in a workflow upload operation",
+        default=10,
+    )
+

 class HttpConfig(BaseSettings):
     """
```
```diff
@@ -271,6 +276,16 @@ class HttpConfig(BaseSettings):
         default=1 * 1024 * 1024,
     )

+    SSRF_DEFAULT_MAX_RETRIES: PositiveInt = Field(
+        description="Maximum number of retries for network requests (SSRF)",
+        default=3,
+    )
+
+    SSRF_PROXY_ALL_URL: Optional[str] = Field(
+        description="Proxy URL for HTTP or HTTPS requests to prevent Server-Side Request Forgery (SSRF)",
+        default=None,
+    )
+
     SSRF_PROXY_HTTP_URL: Optional[str] = Field(
         description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)",
         default=None,
@@ -281,6 +296,26 @@ class HttpConfig(BaseSettings):
         default=None,
     )

+    SSRF_DEFAULT_TIME_OUT: PositiveFloat = Field(
+        description="The default timeout period used for network requests (SSRF)",
+        default=5,
+    )
+
+    SSRF_DEFAULT_CONNECT_TIME_OUT: PositiveFloat = Field(
+        description="The default connect timeout period used for network requests (SSRF)",
+        default=5,
+    )
+
+    SSRF_DEFAULT_READ_TIME_OUT: PositiveFloat = Field(
+        description="The default read timeout period used for network requests (SSRF)",
+        default=5,
+    )
+
+    SSRF_DEFAULT_WRITE_TIME_OUT: PositiveFloat = Field(
+        description="The default write timeout period used for network requests (SSRF)",
+        default=5,
+    )
+
     RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field(
         description="Enable or disable the X-Forwarded-For Proxy Fix middleware from Werkzeug"
         " to respect X-* headers to redirect clients",
```
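These `SSRF_DEFAULT_*` settings feed the `ssrf_proxy` helper used by the remote-file endpoints later in this compare. As a rough sketch of how such settings can drive an httpx client (an illustration under assumed defaults, not the repo's actual `ssrf_proxy` internals):

```python
import httpx

# Assumed constants mirroring the new defaults above.
SSRF_DEFAULT_MAX_RETRIES = 3
SSRF_DEFAULT_TIME_OUT = 5.0
SSRF_DEFAULT_CONNECT_TIME_OUT = 5.0
SSRF_DEFAULT_READ_TIME_OUT = 5.0
SSRF_DEFAULT_WRITE_TIME_OUT = 5.0


def head_with_retries(url: str) -> httpx.Response:
    # Build a per-phase timeout: the first positional value is the fallback
    # for any phase not set explicitly.
    timeout = httpx.Timeout(
        SSRF_DEFAULT_TIME_OUT,
        connect=SSRF_DEFAULT_CONNECT_TIME_OUT,
        read=SSRF_DEFAULT_READ_TIME_OUT,
        write=SSRF_DEFAULT_WRITE_TIME_OUT,
    )
    last_exc: Exception | None = None
    for _ in range(SSRF_DEFAULT_MAX_RETRIES):
        try:
            return httpx.head(url, timeout=timeout)
        except httpx.TransportError as exc:  # retry only transport-level failures
            last_exc = exc
    raise last_exc
```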
```diff
@@ -20,6 +20,7 @@ from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig
 from configs.middleware.vdb.chroma_config import ChromaConfig
 from configs.middleware.vdb.couchbase_config import CouchbaseConfig
 from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
+from configs.middleware.vdb.lindorm_config import LindormConfig
 from configs.middleware.vdb.milvus_config import MilvusConfig
 from configs.middleware.vdb.myscale_config import MyScaleConfig
 from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
@@ -259,6 +260,7 @@ class MiddlewareConfig(
     VikingDBConfig,
     UpstashConfig,
     TidbOnQdrantConfig,
+    LindormConfig,
     OceanBaseVectorConfig,
     BaiduVectorDBConfig,
 ):
```
api/configs/middleware/vdb/lindorm_config.py (new file, 23 lines)

```diff
@@ -0,0 +1,23 @@
+from typing import Optional
+
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+
+class LindormConfig(BaseSettings):
+    """
+    Lindorm configs
+    """
+
+    LINDORM_URL: Optional[str] = Field(
+        description="Lindorm url",
+        default=None,
+    )
+    LINDORM_USERNAME: Optional[str] = Field(
+        description="Lindorm user",
+        default=None,
+    )
+    LINDORM_PASSWORD: Optional[str] = Field(
+        description="Lindorm password",
+        default=None,
+    )
```
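Since `LindormConfig` is a plain pydantic-settings model, its values resolve from environment variables of the same name, matching the `LINDORM_*` entries added to the env example above. A hypothetical usage sketch (the URL here is a placeholder, not a real endpoint):

```python
import os

# pydantic-settings reads LINDORM_* fields from the process environment,
# falling back to the declared defaults (None) for anything unset.
os.environ["LINDORM_URL"] = "http://example-lindorm:30070"  # placeholder URL
os.environ["LINDORM_USERNAME"] = "admin"

config = LindormConfig()
assert config.LINDORM_URL == "http://example-lindorm:30070"
assert config.LINDORM_PASSWORD is None  # unset variables keep the None default
```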
```diff
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.10.2",
+        default="0.11.0",
     )

     COMMIT_SHA: str = Field(
```
api/controllers/common/fields.py (new file, 24 lines)

```diff
@@ -0,0 +1,24 @@
+from flask_restful import fields
+
+parameters__system_parameters = {
+    "image_file_size_limit": fields.Integer,
+    "video_file_size_limit": fields.Integer,
+    "audio_file_size_limit": fields.Integer,
+    "file_size_limit": fields.Integer,
+    "workflow_file_upload_limit": fields.Integer,
+}
+
+parameters_fields = {
+    "opening_statement": fields.String,
+    "suggested_questions": fields.Raw,
+    "suggested_questions_after_answer": fields.Raw,
+    "speech_to_text": fields.Raw,
+    "text_to_speech": fields.Raw,
+    "retriever_resource": fields.Raw,
+    "annotation_reply": fields.Raw,
+    "more_like_this": fields.Raw,
+    "user_input_form": fields.Raw,
+    "sensitive_word_avoidance": fields.Raw,
+    "file_upload": fields.Raw,
+    "system_parameters": fields.Nested(parameters__system_parameters),
+}
```
```diff
@@ -2,11 +2,15 @@ import mimetypes
 import os
 import re
 import urllib.parse
+from collections.abc import Mapping
+from typing import Any
 from uuid import uuid4

 import httpx
 from pydantic import BaseModel

+from configs import dify_config
+

 class FileInfo(BaseModel):
     filename: str
@@ -56,3 +60,38 @@ def guess_file_info_from_response(response: httpx.Response):
         mimetype=mimetype,
         size=int(response.headers.get("Content-Length", -1)),
     )
+
+
+def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]):
+    return {
+        "opening_statement": features_dict.get("opening_statement"),
+        "suggested_questions": features_dict.get("suggested_questions", []),
+        "suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
+        "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
+        "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
+        "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
+        "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
+        "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
+        "user_input_form": user_input_form,
+        "sensitive_word_avoidance": features_dict.get(
+            "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
+        ),
+        "file_upload": features_dict.get(
+            "file_upload",
+            {
+                "image": {
+                    "enabled": False,
+                    "number_limits": 3,
+                    "detail": "high",
+                    "transfer_methods": ["remote_url", "local_file"],
+                }
+            },
+        ),
+        "system_parameters": {
+            "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
+            "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
+            "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
+            "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
+            "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
+        },
+    }
```
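The three parameter endpoints (console explore, service API, and web API) previously each built this response dict inline; the hunks later in this compare swap those copies for this one shared helper. A small usage sketch, assuming it runs inside the api app so `dify_config` resolves (the feature values are made up):

```python
# Hypothetical call: defaults are filled in for any feature key the app
# config does not set explicitly.
features_dict = {"opening_statement": "Hello!", "speech_to_text": {"enabled": True}}
params = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=[])

assert params["speech_to_text"] == {"enabled": True}   # passed through
assert params["more_like_this"] == {"enabled": False}  # default applied
assert "workflow_file_upload_limit" in params["system_parameters"]
```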
```diff
@@ -456,7 +456,7 @@ class DatasetIndexingEstimateApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -620,6 +620,7 @@ class DatasetRetrievalSettingApi(Resource):
                 case (
                     VectorType.MILVUS
                     | VectorType.RELYT
                     | VectorType.PGVECTOR
                     | VectorType.TIDB_VECTOR
                     | VectorType.CHROMA
                     | VectorType.TENCENT
@@ -640,6 +641,7 @@ class DatasetRetrievalSettingApi(Resource):
                     | VectorType.ELASTICSEARCH
                     | VectorType.PGVECTOR
                     | VectorType.TIDB_ON_QDRANT
+                    | VectorType.LINDORM
                     | VectorType.COUCHBASE
                 ):
                     return {
@@ -682,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                     | VectorType.ELASTICSEARCH
                     | VectorType.COUCHBASE
                     | VectorType.PGVECTOR
+                    | VectorType.LINDORM
                 ):
                     return {
                         "retrieval_method": [
```
```diff
@@ -945,7 +945,7 @@ class DocumentRetryApi(DocumentResource):
                     raise DocumentAlreadyFinishedError()
                 retry_documents.append(document)
             except Exception as e:
-                logging.error(f"Document {document_id} retry failed: {str(e)}")
+                logging.exception(f"Document {document_id} retry failed: {str(e)}")
                 continue
         # retry document
         DocumentService.retry_document(dataset_id, retry_documents)
```
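`logging.exception` logs the same message at ERROR level but also attaches the active traceback, which `logging.error` does not do by default. A minimal illustration:

```python
import logging

try:
    1 / 0
except ZeroDivisionError:
    # Unlike logging.error, this appends the full traceback to the record,
    # which is why the changeset prefers it inside except blocks.
    logging.exception("Document retry failed")
```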
```diff
@@ -62,3 +62,27 @@ class EmailSendIpLimitError(BaseHTTPException):
     error_code = "email_send_ip_limit"
     description = "Too many emails have been sent from this IP address recently. Please try again later."
     code = 429
+
+
+class FileTooLargeError(BaseHTTPException):
+    error_code = "file_too_large"
+    description = "File size exceeded. {message}"
+    code = 413
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = "unsupported_file_type"
+    description = "File type not allowed."
+    code = 415
+
+
+class TooManyFilesError(BaseHTTPException):
+    error_code = "too_many_files"
+    description = "Only one file is allowed."
+    code = 400
+
+
+class NoFileUploadedError(BaseHTTPException):
+    error_code = "no_file_uploaded"
+    description = "Please upload your file."
+    code = 400
```
```diff
@@ -1,6 +1,7 @@
-from flask_restful import fields, marshal_with
+from flask_restful import marshal_with

-from configs import dify_config
+from controllers.common import fields
+from controllers.common import helpers as controller_helpers
 from controllers.console import api
 from controllers.console.app.error import AppUnavailableError
 from controllers.console.explore.wraps import InstalledAppResource
@@ -11,43 +12,14 @@ from services.app_service import AppService
 class AppParameterApi(InstalledAppResource):
     """Resource for app variables."""

-    variable_fields = {
-        "key": fields.String,
-        "name": fields.String,
-        "description": fields.String,
-        "type": fields.String,
-        "default": fields.String,
-        "max_length": fields.Integer,
-        "options": fields.List(fields.String),
-    }
-
-    system_parameters_fields = {
-        "image_file_size_limit": fields.Integer,
-        "video_file_size_limit": fields.Integer,
-        "audio_file_size_limit": fields.Integer,
-        "file_size_limit": fields.Integer,
-    }
-
-    parameters_fields = {
-        "opening_statement": fields.String,
-        "suggested_questions": fields.Raw,
-        "suggested_questions_after_answer": fields.Raw,
-        "speech_to_text": fields.Raw,
-        "text_to_speech": fields.Raw,
-        "retriever_resource": fields.Raw,
-        "annotation_reply": fields.Raw,
-        "more_like_this": fields.Raw,
-        "user_input_form": fields.Raw,
-        "sensitive_word_avoidance": fields.Raw,
-        "file_upload": fields.Raw,
-        "system_parameters": fields.Nested(system_parameters_fields),
-    }
-
-    @marshal_with(parameters_fields)
+    @marshal_with(fields.parameters_fields)
     def get(self, installed_app: InstalledApp):
         """Retrieve app parameters."""
         app_model = installed_app.app

         if app_model is None:
             raise AppUnavailableError()

         if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
             workflow = app_model.workflow
             if workflow is None:
@@ -57,43 +29,16 @@ class AppParameterApi(InstalledAppResource):
             user_input_form = workflow.user_input_form(to_old_structure=True)
         else:
             app_model_config = app_model.app_model_config
             if app_model_config is None:
                 raise AppUnavailableError()

             features_dict = app_model_config.to_dict()

             user_input_form = features_dict.get("user_input_form", [])

-        return {
-            "opening_statement": features_dict.get("opening_statement"),
-            "suggested_questions": features_dict.get("suggested_questions", []),
-            "suggested_questions_after_answer": features_dict.get(
-                "suggested_questions_after_answer", {"enabled": False}
-            ),
-            "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
-            "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
-            "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
-            "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
-            "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
-            "user_input_form": user_input_form,
-            "sensitive_word_avoidance": features_dict.get(
-                "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
-            ),
-            "file_upload": features_dict.get(
-                "file_upload",
-                {
-                    "image": {
-                        "enabled": False,
-                        "number_limits": 3,
-                        "detail": "high",
-                        "transfer_methods": ["remote_url", "local_file"],
-                    }
-                },
-            ),
-            "system_parameters": {
-                "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
-                "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
-                "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
-                "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
-            },
-        }
+        return controller_helpers.get_parameters_from_feature_dict(
+            features_dict=features_dict, user_input_form=user_input_form
+        )


 class ExploreAppMetaApi(InstalledAppResource):
```
```diff
@@ -15,7 +15,7 @@ from fields.file_fields import file_fields, upload_config_fields
 from libs.login import login_required
 from services.file_service import FileService

-from .errors import (
+from .error import (
     FileTooLargeError,
     NoFileUploadedError,
     TooManyFilesError,
@@ -37,6 +37,7 @@ class FileApi(Resource):
             "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
             "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
             "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
+            "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
         }, 200

     @setup_required
```
```diff
@@ -1,25 +0,0 @@
-from libs.exception import BaseHTTPException
-
-
-class FileTooLargeError(BaseHTTPException):
-    error_code = "file_too_large"
-    description = "File size exceeded. {message}"
-    code = 413
-
-
-class UnsupportedFileTypeError(BaseHTTPException):
-    error_code = "unsupported_file_type"
-    description = "File type not allowed."
-    code = 415
-
-
-class TooManyFilesError(BaseHTTPException):
-    error_code = "too_many_files"
-    description = "Only one file is allowed."
-    code = 400
-
-
-class NoFileUploadedError(BaseHTTPException):
-    error_code = "no_file_uploaded"
-    description = "Please upload your file."
-    code = 400
```
```diff
@@ -1,9 +1,11 @@
 import urllib.parse
 from typing import cast

+import httpx
 from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse

+import services
 from controllers.common import helpers
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy
@@ -11,19 +13,25 @@ from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
 from models.account import Account
 from services.file_service import FileService

+from .error import (
+    FileTooLargeError,
+    UnsupportedFileTypeError,
+)
+

 class RemoteFileInfoApi(Resource):
     @marshal_with(remote_file_info_fields)
     def get(self, url):
         decoded_url = urllib.parse.unquote(url)
-        try:
-            response = ssrf_proxy.head(decoded_url)
-            return {
-                "file_type": response.headers.get("Content-Type", "application/octet-stream"),
-                "file_length": int(response.headers.get("Content-Length", 0)),
-            }
-        except Exception as e:
-            return {"error": str(e)}, 400
+        resp = ssrf_proxy.head(decoded_url)
+        if resp.status_code != httpx.codes.OK:
+            # failed back to get method
+            resp = ssrf_proxy.get(decoded_url, timeout=3)
+        resp.raise_for_status()
+        return {
+            "file_type": resp.headers.get("Content-Type", "application/octet-stream"),
+            "file_length": int(resp.headers.get("Content-Length", 0)),
+        }


 class RemoteFileUploadApi(Resource):
@@ -35,17 +43,17 @@ class RemoteFileUploadApi(Resource):

         url = args["url"]

-        response = ssrf_proxy.head(url)
-        response.raise_for_status()
+        resp = ssrf_proxy.head(url=url)
+        if resp.status_code != httpx.codes.OK:
+            resp = ssrf_proxy.get(url=url, timeout=3)
+            resp.raise_for_status()

-        file_info = helpers.guess_file_info_from_response(response)
+        file_info = helpers.guess_file_info_from_response(resp)

         if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
-            return {"error": "File size exceeded"}, 400
+            raise FileTooLargeError

-        response = ssrf_proxy.get(url)
-        response.raise_for_status()
-        content = response.content
+        content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content

         try:
             user = cast(Account, current_user)
@@ -56,8 +64,10 @@ class RemoteFileUploadApi(Resource):
                 user=user,
                 source_url=url,
             )
-        except Exception as e:
-            return {"error": str(e)}, 400
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()

         return {
             "id": upload_file.id,
```
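The console remote-file endpoints above (and their web counterparts later in this compare) now share the same probing pattern: try HEAD first, and if the server refuses it, fall back to a short GET whose body can be reused so the file is not downloaded twice. A standalone sketch of the idea, using plain httpx instead of the repo's `ssrf_proxy` wrapper:

```python
import httpx


def probe_remote_file(url: str) -> httpx.Response:
    """Try HEAD first; some servers reject it, so fall back to a short GET."""
    resp = httpx.head(url)
    if resp.status_code != httpx.codes.OK:
        resp = httpx.get(url, timeout=3)
    resp.raise_for_status()
    return resp


# A caller can then reuse resp.content when the fallback GET already fetched
# the body, mirroring the `resp.request.method == "GET"` check in the diff.
```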
```diff
@@ -1,6 +1,7 @@
-from flask_restful import Resource, fields, marshal_with
+from flask_restful import Resource, marshal_with

-from configs import dify_config
+from controllers.common import fields
+from controllers.common import helpers as controller_helpers
 from controllers.service_api import api
 from controllers.service_api.app.error import AppUnavailableError
 from controllers.service_api.wraps import validate_app_token
@@ -11,40 +12,8 @@ from services.app_service import AppService
 class AppParameterApi(Resource):
     """Resource for app variables."""

-    variable_fields = {
-        "key": fields.String,
-        "name": fields.String,
-        "description": fields.String,
-        "type": fields.String,
-        "default": fields.String,
-        "max_length": fields.Integer,
-        "options": fields.List(fields.String),
-    }
-
-    system_parameters_fields = {
-        "image_file_size_limit": fields.Integer,
-        "video_file_size_limit": fields.Integer,
-        "audio_file_size_limit": fields.Integer,
-        "file_size_limit": fields.Integer,
-    }
-
-    parameters_fields = {
-        "opening_statement": fields.String,
-        "suggested_questions": fields.Raw,
-        "suggested_questions_after_answer": fields.Raw,
-        "speech_to_text": fields.Raw,
-        "text_to_speech": fields.Raw,
-        "retriever_resource": fields.Raw,
-        "annotation_reply": fields.Raw,
-        "more_like_this": fields.Raw,
-        "user_input_form": fields.Raw,
-        "sensitive_word_avoidance": fields.Raw,
-        "file_upload": fields.Raw,
-        "system_parameters": fields.Nested(system_parameters_fields),
-    }
-
     @validate_app_token
-    @marshal_with(parameters_fields)
+    @marshal_with(fields.parameters_fields)
     def get(self, app_model: App):
         """Retrieve app parameters."""
         if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
@@ -56,43 +25,16 @@ class AppParameterApi(Resource):
             user_input_form = workflow.user_input_form(to_old_structure=True)
         else:
             app_model_config = app_model.app_model_config
             if app_model_config is None:
                 raise AppUnavailableError()

             features_dict = app_model_config.to_dict()

             user_input_form = features_dict.get("user_input_form", [])

-        return {
-            "opening_statement": features_dict.get("opening_statement"),
-            "suggested_questions": features_dict.get("suggested_questions", []),
-            "suggested_questions_after_answer": features_dict.get(
-                "suggested_questions_after_answer", {"enabled": False}
-            ),
-            "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
-            "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
-            "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
-            "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
-            "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
-            "user_input_form": user_input_form,
-            "sensitive_word_avoidance": features_dict.get(
-                "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
-            ),
-            "file_upload": features_dict.get(
-                "file_upload",
-                {
-                    "image": {
-                        "enabled": False,
-                        "number_limits": 3,
-                        "detail": "high",
-                        "transfer_methods": ["remote_url", "local_file"],
-                    }
-                },
-            ),
-            "system_parameters": {
-                "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
-                "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
-                "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
-                "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
-            },
-        }
+        return controller_helpers.get_parameters_from_feature_dict(
+            features_dict=features_dict, user_input_form=user_input_form
+        )


 class AppMetaApi(Resource):
```
```diff
@@ -7,7 +7,11 @@ from controllers.service_api import api
 from controllers.service_api.app.error import NotChatAppError
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.app.entities.app_invoke_entities import InvokeFrom
-from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
+from fields.conversation_fields import (
+    conversation_delete_fields,
+    conversation_infinite_scroll_pagination_fields,
+    simple_conversation_fields,
+)
 from libs.helper import uuid_value
 from models.model import App, AppMode, EndUser
 from services.conversation_service import ConversationService
@@ -49,7 +53,7 @@ class ConversationApi(Resource):

 class ConversationDetailApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
-    @marshal_with(simple_conversation_fields)
+    @marshal_with(conversation_delete_fields)
     def delete(self, app_model: App, end_user: EndUser, c_id):
         app_mode = AppMode.value_of(app_model.mode)
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -58,10 +62,9 @@ class ConversationDetailApi(Resource):
         conversation_id = str(c_id)

         try:
-            ConversationService.delete(app_model, conversation_id, end_user)
+            return ConversationService.delete(app_model, conversation_id, end_user)
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
-        return {"result": "success"}, 200


 class ConversationRenameApi(Resource):
```
```diff
@@ -41,7 +41,6 @@ class FileApi(Resource):
                 content=file.read(),
                 mimetype=file.mimetype,
                 user=end_user,
-                source="datasets",
             )
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
```
```diff
@@ -1,6 +1,7 @@
-from flask_restful import fields, marshal_with
+from flask_restful import marshal_with

-from configs import dify_config
+from controllers.common import fields
+from controllers.common import helpers as controller_helpers
 from controllers.web import api
 from controllers.web.error import AppUnavailableError
 from controllers.web.wraps import WebApiResource
@@ -11,39 +12,7 @@ from services.app_service import AppService
 class AppParameterApi(WebApiResource):
     """Resource for app variables."""

-    variable_fields = {
-        "key": fields.String,
-        "name": fields.String,
-        "description": fields.String,
-        "type": fields.String,
-        "default": fields.String,
-        "max_length": fields.Integer,
-        "options": fields.List(fields.String),
-    }
-
-    system_parameters_fields = {
-        "image_file_size_limit": fields.Integer,
-        "video_file_size_limit": fields.Integer,
-        "audio_file_size_limit": fields.Integer,
-        "file_size_limit": fields.Integer,
-    }
-
-    parameters_fields = {
-        "opening_statement": fields.String,
-        "suggested_questions": fields.Raw,
-        "suggested_questions_after_answer": fields.Raw,
-        "speech_to_text": fields.Raw,
-        "text_to_speech": fields.Raw,
-        "retriever_resource": fields.Raw,
-        "annotation_reply": fields.Raw,
-        "more_like_this": fields.Raw,
-        "user_input_form": fields.Raw,
-        "sensitive_word_avoidance": fields.Raw,
-        "file_upload": fields.Raw,
-        "system_parameters": fields.Nested(system_parameters_fields),
-    }
-
-    @marshal_with(parameters_fields)
+    @marshal_with(fields.parameters_fields)
     def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
@@ -55,43 +24,16 @@ class AppParameterApi(WebApiResource):
             user_input_form = workflow.user_input_form(to_old_structure=True)
         else:
             app_model_config = app_model.app_model_config
             if app_model_config is None:
                 raise AppUnavailableError()

             features_dict = app_model_config.to_dict()

             user_input_form = features_dict.get("user_input_form", [])

-        return {
-            "opening_statement": features_dict.get("opening_statement"),
-            "suggested_questions": features_dict.get("suggested_questions", []),
-            "suggested_questions_after_answer": features_dict.get(
-                "suggested_questions_after_answer", {"enabled": False}
-            ),
-            "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
-            "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
-            "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
-            "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
-            "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
-            "user_input_form": user_input_form,
-            "sensitive_word_avoidance": features_dict.get(
-                "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
-            ),
-            "file_upload": features_dict.get(
-                "file_upload",
-                {
-                    "image": {
-                        "enabled": False,
-                        "number_limits": 3,
-                        "detail": "high",
-                        "transfer_methods": ["remote_url", "local_file"],
-                    }
-                },
-            ),
-            "system_parameters": {
-                "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
-                "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
-                "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
-                "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
-            },
-        }
+        return controller_helpers.get_parameters_from_feature_dict(
+            features_dict=features_dict, user_input_form=user_input_form
+        )


 class AppMeta(WebApiResource):
```
```diff
@@ -1,8 +1,9 @@
 import urllib.parse

-from flask_login import current_user
+import httpx
 from flask_restful import marshal_with, reqparse

+import services
 from controllers.common import helpers
 from controllers.web.wraps import WebApiResource
 from core.file import helpers as file_helpers
@@ -10,52 +11,57 @@ from core.helper import ssrf_proxy
 from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
 from services.file_service import FileService

+from .error import FileTooLargeError, UnsupportedFileTypeError
+

 class RemoteFileInfoApi(WebApiResource):
     @marshal_with(remote_file_info_fields)
-    def get(self, url):
+    def get(self, app_model, end_user, url):
         decoded_url = urllib.parse.unquote(url)
-        try:
-            response = ssrf_proxy.head(decoded_url)
-            return {
-                "file_type": response.headers.get("Content-Type", "application/octet-stream"),
-                "file_length": int(response.headers.get("Content-Length", -1)),
-            }
-        except Exception as e:
-            return {"error": str(e)}, 400
+        resp = ssrf_proxy.head(decoded_url)
+        if resp.status_code != httpx.codes.OK:
+            # failed back to get method
+            resp = ssrf_proxy.get(decoded_url, timeout=3)
+        resp.raise_for_status()
+        return {
+            "file_type": resp.headers.get("Content-Type", "application/octet-stream"),
+            "file_length": int(resp.headers.get("Content-Length", -1)),
+        }


 class RemoteFileUploadApi(WebApiResource):
     @marshal_with(file_fields_with_signed_url)
-    def post(self):
+    def post(self, app_model, end_user):  # Add app_model and end_user parameters
         parser = reqparse.RequestParser()
         parser.add_argument("url", type=str, required=True, help="URL is required")
         args = parser.parse_args()

         url = args["url"]

-        response = ssrf_proxy.head(url)
-        response.raise_for_status()
+        resp = ssrf_proxy.head(url=url)
+        if resp.status_code != httpx.codes.OK:
+            resp = ssrf_proxy.get(url=url, timeout=3)
+            resp.raise_for_status()

-        file_info = helpers.guess_file_info_from_response(response)
+        file_info = helpers.guess_file_info_from_response(resp)

         if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
-            return {"error": "File size exceeded"}, 400
+            raise FileTooLargeError

-        response = ssrf_proxy.get(url)
-        response.raise_for_status()
-        content = response.content
+        content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content

         try:
             upload_file = FileService.upload_file(
                 filename=file_info.filename,
                 content=content,
                 mimetype=file_info.mimetype,
-                user=current_user,
+                user=end_user,
                 source_url=url,
             )
-        except Exception as e:
-            return {"error": str(e)}, 400
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError

         return {
             "id": upload_file.id,
```
```diff
@@ -1,8 +1,7 @@
 from collections.abc import Mapping
 from typing import Any

-from core.file.models import FileExtraConfig
-from models import FileUploadConfig
+from core.file import FileExtraConfig


 class FileUploadConfigManager:
@@ -43,6 +42,6 @@ class FileUploadConfigManager:
         if not config.get("file_upload"):
             config["file_upload"] = {}
         else:
-            FileUploadConfig.model_validate(config["file_upload"])
+            FileExtraConfig.model_validate(config["file_upload"])

         return config, ["file_upload"]
```
```diff
@@ -1,6 +1,5 @@
 import contextvars
 import logging
-import os
 import threading
 import uuid
 from collections.abc import Generator
@@ -10,6 +9,7 @@ from flask import Flask, current_app
 from pydantic import ValidationError

 import contexts
+from configs import dify_config
 from constants import UUID_NIL
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
@@ -317,7 +317,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except (ValueError, InvokeError) as e:
-            if os.environ.get("DEBUG", "false").lower() == "true":
+            if dify_config.DEBUG:
                 logger.exception("Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except Exception as e:
```
```diff
@@ -20,6 +20,7 @@ from core.app.entities.queue_entities import (
     QueueIterationStartEvent,
     QueueMessageReplaceEvent,
     QueueNodeFailedEvent,
+    QueueNodeInIterationFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -241,7 +242,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage
                     start_listener_time = time.time()
                     yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
             except Exception as e:
-                logger.error(e)
+                logger.exception(e)
                 break
         if tts_publisher:
             yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
@@ -314,7 +315,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage

                 if response:
                     yield response
-            elif isinstance(event, QueueNodeFailedEvent):
+            elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_failed(event)

                 response = self._workflow_node_finish_to_stream_response(
```
```diff
@@ -1,5 +1,4 @@
 import logging
-import os
 import threading
 import uuid
 from collections.abc import Generator
@@ -8,6 +7,7 @@ from typing import Any, Literal, Union, overload
 from flask import Flask, current_app
 from pydantic import ValidationError

+from configs import dify_config
 from constants import UUID_NIL
 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -230,7 +230,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             logger.exception("Validation Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except (ValueError, InvokeError) as e:
-            if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
+            if dify_config.DEBUG:
                 logger.exception("Error when generating")
             queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except Exception as e:
```
@@ -22,7 +22,10 @@ class BaseAppGenerator:
|
||||
user_inputs = user_inputs or {}
|
||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||
variables = app_config.variables
|
||||
user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||
user_inputs = {
|
||||
var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
|
||||
for var in variables
|
||||
}
|
||||
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
|
||||
# Convert files in inputs to File
|
||||
entity_dictionary = {item.variable: item for item in app_config.variables}
|
||||
@@ -74,50 +77,66 @@ class BaseAppGenerator:
|
||||
|
||||
return user_inputs
|
||||
|
||||
def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
|
||||
user_input_value = inputs.get(var.variable)
|
||||
if not user_input_value:
|
||||
if var.required:
|
||||
raise ValueError(f"{var.variable} is required in input form")
|
||||
else:
|
||||
return None
|
||||
def _validate_inputs(
|
||||
self,
|
||||
*,
|
||||
variable_entity: "VariableEntity",
|
||||
value: Any,
|
||||
):
|
||||
if value is None:
|
||||
if variable_entity.required:
|
||||
raise ValueError(f"{variable_entity.variable} is required in input form")
|
||||
return value
|
||||
|
||||
if var.type in {
|
||||
if variable_entity.type in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
} and not isinstance(user_input_value, str):
|
||||
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
|
||||
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
|
||||
} and not isinstance(value, str):
|
||||
raise ValueError(
|
||||
f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string"
|
||||
)
|
||||
|
||||
if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
# may raise ValueError if value is not a valid number
try:
if "." in user_input_value:
return float(user_input_value)
if "." in value:
return float(value)
else:
return int(user_input_value)
return int(value)
except ValueError:
raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntityType.SELECT:
options = var.options
if user_input_value not in options:
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
if var.max_length and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
elif var.type == VariableEntityType.FILE:
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
raise ValueError(f"{var.variable} in input form must be a file")
elif var.type == VariableEntityType.FILE_LIST:
if not (
isinstance(user_input_value, list)
and (
all(isinstance(item, dict) for item in user_input_value)
or all(isinstance(item, File) for item in user_input_value)
)
):
raise ValueError(f"{var.variable} in input form must be a list of files")
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")

return user_input_value
match variable_entity.type:
case VariableEntityType.SELECT:
if value not in variable_entity.options:
raise ValueError(
f"{variable_entity.variable} in input form must be one of the following: "
f"{variable_entity.options}"
)
case VariableEntityType.TEXT_INPUT | VariableEntityType.PARAGRAPH:
if variable_entity.max_length and len(value) > variable_entity.max_length:
raise ValueError(
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} "
"characters"
)
case VariableEntityType.FILE:
if not isinstance(value, dict) and not isinstance(value, File):
raise ValueError(f"{variable_entity.variable} in input form must be a file")
case VariableEntityType.FILE_LIST:
# if the number of files exceeds the limit, raise ValueError
if not (
isinstance(value, list)
and (all(isinstance(item, dict) for item in value) or all(isinstance(item, File) for item in value))
):
raise ValueError(f"{variable_entity.variable} in input form must be a list of files")

if variable_entity.max_length and len(value) > variable_entity.max_length:
raise ValueError(
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
)

return value

def _sanitize_value(self, value: Any) -> Any:
if isinstance(value, str):
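
One behavioral consequence of the `_validate_input` to `_validate_inputs` rewrite is the missing-value test: the old `if not user_input_value` treated any falsy input (empty string, `0`, `False`) as missing, while the new `if value is None` only short-circuits on genuinely absent values and lets falsy ones continue into the type checks. A quick illustration of the difference:

def old_is_missing(v):  # old check: any falsy value counts as missing
    return not v

def new_is_missing(v):  # new check: only None counts as missing
    return v is None

assert old_is_missing("") and old_is_missing(0)
assert not new_is_missing("") and not new_is_missing(0)
assert new_is_missing(None)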

@@ -1,5 +1,4 @@
import logging
import os
import threading
import uuid
from collections.abc import Generator
@@ -8,6 +7,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError

from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -227,7 +227,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

@@ -1,5 +1,4 @@
import logging
import os
import threading
import uuid
from collections.abc import Generator
@@ -8,6 +7,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError

from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
@@ -203,7 +203,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

@@ -1,6 +1,5 @@
import contextvars
import logging
import os
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
@@ -10,6 +9,7 @@ from flask import Flask, current_app
from pydantic import ValidationError

import contexts
from configs import dify_config
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
@@ -261,7 +261,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true":
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

@@ -16,6 +16,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -215,7 +216,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else:
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
logger.error(e)
logger.exception(e)
break
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
@@ -275,7 +276,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

response = self._workflow_node_finish_to_stream_response(

@@ -9,6 +9,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -30,6 +31,7 @@ from core.workflow.graph_engine.entities.event import (
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
@@ -193,6 +195,7 @@ class WorkflowBasedAppRunner(AppRunner):
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
parallel_mode_run_id=event.parallel_mode_run_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
@@ -246,9 +249,40 @@ class WorkflowBasedAppRunner(AppRunner):
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
error=event.error,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
@@ -326,6 +360,7 @@ class WorkflowBasedAppRunner(AppRunner):
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
parallel_mode_run_id=event.parallel_mode_run_id,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):

@@ -107,7 +107,8 @@ class QueueIterationNextEvent(AppQueueEvent):
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""

parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None  # output for the current iteration

@@ -273,6 +274,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""


class QueueNodeSucceededEvent(AppQueueEvent):
@@ -306,6 +309,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
error: Optional[str] = None


class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity
"""

event: QueueEvent = QueueEvent.NODE_FAILED

node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime

inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

error: str


class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity
@@ -332,6 +366,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

error: str


@@ -244,6 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
parallel_run_id: Optional[str] = None

event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
@@ -432,6 +433,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None

event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str

@@ -12,6 +12,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -35,6 +36,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
@@ -251,6 +253,12 @@ class WorkflowCycleManage:
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)

session.add(workflow_node_execution)
@@ -305,7 +313,9 @@ class WorkflowCycleManage:

return workflow_node_execution

def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
@@ -318,16 +328,19 @@ class WorkflowCycleManage:
outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()

execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
WorkflowNodeExecution.execution_metadata: execution_metadata,
}
)

@@ -342,6 +355,7 @@ class WorkflowCycleManage:
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata

self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)

@@ -448,6 +462,7 @@ class WorkflowCycleManage:
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
parallel_run_id=event.parallel_mode_run_id,
),
)

@@ -464,7 +479,7 @@ class WorkflowCycleManage:

def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
@@ -608,6 +623,7 @@ class WorkflowCycleManage:
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
),
)

@@ -633,7 +649,9 @@ class WorkflowCycleManage:
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,

@@ -3,22 +3,20 @@ Proxy requests to avoid SSRF
"""

import logging
import os
import time

import httpx

SSRF_PROXY_ALL_URL = os.getenv("SSRF_PROXY_ALL_URL", "")
SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "")
SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "")
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))
from configs import dify_config

SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES

proxy_mounts = (
{
"http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL),
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
}
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL
else None
)
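
For context, `mounts` is httpx's mechanism for routing each URL scheme through its own transport, which is how the SSRF proxy is applied per scheme here. A standalone sketch (the proxy URL is a placeholder):

import httpx

# Route http:// and https:// traffic through a forward proxy.
# The proxy URL below is an illustrative placeholder.
mounts = {
    "http://": httpx.HTTPTransport(proxy="http://proxy.internal:3128"),
    "https://": httpx.HTTPTransport(proxy="http://proxy.internal:3128"),
}

with httpx.Client(mounts=mounts) as client:
    response = client.get("https://example.com")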

@@ -32,11 +30,19 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "follow_redirects" not in kwargs:
kwargs["follow_redirects"] = allow_redirects

if "timeout" not in kwargs:
kwargs["timeout"] = httpx.Timeout(
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
)

retries = 0
while retries <= max_retries:
try:
if SSRF_PROXY_ALL_URL:
with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client:
if dify_config.SSRF_PROXY_ALL_URL:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL) as client:
response = client.request(method=method, url=url, **kwargs)
elif proxy_mounts:
with httpx.Client(mounts=proxy_mounts) as client:
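
`httpx.Timeout` lets the hunk above set one overall default plus per-phase (connect/read/write) overrides. A quick sketch of how such a default merges with caller kwargs; the numeric values are placeholders standing in for the `SSRF_DEFAULT_*` config fields:

import httpx

# Placeholder defaults standing in for the dify_config SSRF_* values.
DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0, read=5.0, write=5.0)

def make_request(method: str, url: str, **kwargs):
    # Only apply the default when the caller did not pass a timeout,
    # mirroring the `if "timeout" not in kwargs` guard above.
    kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
    with httpx.Client() as client:
        return client.request(method=method, url=url, **kwargs)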

@@ -598,7 +598,7 @@ class IndexingRunner:
rules = DatasetProcessRule.AUTOMATIC_RULES
else:
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
document_text = CleanProcessor.clean(text, rules)
document_text = CleanProcessor.clean(text, {"rules": rules})

return document_text
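
The fix wraps the parsed rules in a `{"rules": ...}` envelope, which suggests `CleanProcessor.clean` reads its options from a nested `rules` key. A purely hypothetical illustration of why the unwrapped call would misbehave (the real `CleanProcessor` is not shown in this diff):

# Hypothetical stand-in for CleanProcessor.clean, assuming it looks up
# its process rules under a top-level "rules" key.
def clean(text: str, process_rule: dict) -> str:
    rules = process_rule.get("rules", {})
    if rules.get("remove_extra_spaces"):
        text = " ".join(text.split())
    return text

rules = {"remove_extra_spaces": True}
assert clean("a  b", rules) == "a  b"            # old call: option silently ignored
assert clean("a  b", {"rules": rules}) == "a b"  # new call: option applied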

@@ -1,8 +1,8 @@
import logging
import os
from collections.abc import Callable, Generator, Iterable, Sequence
from typing import IO, Any, Optional, Union, cast

from configs import dify_config
from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
@@ -473,7 +473,7 @@ class LBModelManager:

continue

if bool(os.environ.get("DEBUG", "False").lower() == "true"):
if dify_config.DEBUG:
logger.info(
f"Model LB\nid: {config.id}\nname:{config.name}\n"
f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"

@@ -1,3 +1,4 @@
- claude-3-5-haiku-20241022
- claude-3-5-sonnet-20241022
- claude-3-5-sonnet-20240620
- claude-3-haiku-20240307

@@ -0,0 +1,38 @@
model: claude-3-5-haiku-20241022
label:
en_US: claude-3-5-haiku-20241022
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
pricing:
input: '1.00'
output: '5.00'
unit: '0.000001'
currency: USD
@@ -47,9 +47,9 @@ class AzureRerankModel(RerankModel):
result = response.read()
return json.loads(result)
except urllib.error.HTTPError as error:
logger.error(f"The request failed with status code: {error.code}")
logger.error(error.info())
logger.error(error.read().decode("utf8", "ignore"))
logger.exception(f"The request failed with status code: {error.code}")
logger.exception(error.info())
logger.exception(error.read().decode("utf8", "ignore"))
raise

def _invoke(

@@ -0,0 +1,60 @@
model: anthropic.claude-3-5-haiku-20241022-v1:0
label:
en_US: Claude 3.5 Haiku
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 8192
min: 1
max: 8192
help:
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip: the AWS docs have an error here; the max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.001'
output: '0.005'
unit: '0.001'
currency: USD
@@ -0,0 +1,60 @@
model: us.anthropic.claude-3-5-haiku-20241022-v1:0
label:
en_US: Claude 3.5 Haiku (US Cross-Region Inference)
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip: the AWS docs have an error here; the max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.001'
output: '0.005'
unit: '0.001'
currency: USD
@@ -1,6 +1,7 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
import requests

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

@@ -16,8 +17,18 @@ class GiteeAIProvider(ModelProvider):
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(model="Qwen2-7B-Instruct", credentials=credentials)
api_key = credentials.get("api_key")
if not api_key:
raise CredentialsValidateFailedError("Credentials validation failed: api_key not given")

# send a get request to validate the credentials
headers = {"Authorization": f"Bearer {api_key}"}
response = requests.get("https://ai.gitee.com/api/base/account/me", headers=headers, timeout=(10, 300))

if response.status_code != 200:
raise CredentialsValidateFailedError(
f"Credentials validation failed with status code {response.status_code}"
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
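
The rewritten GiteeAI check validates the key with a plain HTTP call instead of spinning up a model instance. Worth noting: in `requests`, a tuple timeout is `(connect, read)`, so `timeout=(10, 300)` allows 10 s to establish the connection and 300 s to read the response. A minimal sketch of the same pattern (the endpoint is a placeholder):

import requests

def check_api_key(api_key: str) -> None:
    # Placeholder endpoint; the real code calls the provider's own
    # account endpoint. timeout=(connect_seconds, read_seconds).
    response = requests.get(
        "https://api.example.com/v1/me",
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=(10, 300),
    )
    if response.status_code != 200:
        raise ValueError(f"credential check failed: {response.status_code}")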

@@ -0,0 +1,38 @@
model: anthropic/claude-3-5-haiku
label:
en_US: claude-3-5-haiku
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
pricing:
input: "1"
output: "5"
unit: "0.000001"
currency: USD
@@ -5,6 +5,7 @@ label:
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 131072

@@ -5,6 +5,7 @@ label:
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 131072

@@ -4,10 +4,7 @@ import re
from collections.abc import Generator, Iterator
from typing import Any, Optional, Union, cast

# from openai.types.chat import ChatCompletion, ChatCompletionChunk
import boto3
from sagemaker import Predictor, serializers
from sagemaker.session import Session

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
@@ -212,6 +209,9 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
from sagemaker import Predictor, serializers
from sagemaker.session import Session

if not self.sagemaker_session:
access_key = credentials.get("aws_access_key_id")
secret_key = credentials.get("aws_secret_access_key")

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24" aria-hidden="true" class="" focusable="false" style="fill:currentColor;height:28px;width:28px"><path d="m3.005 8.858 8.783 12.544h3.904L6.908 8.858zM6.905 15.825 3 21.402h3.907l1.951-2.788zM16.585 2l-6.75 9.64 1.953 2.79L20.492 2zM17.292 7.965v13.437h3.2V3.395z"></path></svg>

63
api/core/model_runtime/model_providers/x/llm/grok-beta.yaml
Normal file
@@ -0,0 +1,63 @@
model: grok-beta
label:
en_US: Grok beta
model_type: llm
features:
- multi-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 2.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

- name: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: 0
max: 2.0
precision: 1
required: false
help:
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"

- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

37
api/core/model_runtime/model_providers/x/llm/llm.py
Normal file
@@ -0,0 +1,37 @@
from collections.abc import Generator
from typing import Optional, Union

from yarl import URL

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)

@staticmethod
def _add_custom_parameters(credentials) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1"
credentials["mode"] = LLMMode.CHAT.value
credentials["function_calling_type"] = "tool_call"

25
api/core/model_runtime/model_providers/x/x.py
Normal file
@@ -0,0 +1,25 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class XAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception

:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(model="grok-beta", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

38
api/core/model_runtime/model_providers/x/x.yaml
Normal file
@@ -0,0 +1,38 @@
provider: x
label:
en_US: xAI
description:
en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
icon_small:
en_US: x-ai-logo.svg
icon_large:
en_US: x-ai-logo.svg
help:
title:
en_US: Get your token from xAI
zh_Hans: 从 xAI 获取 token
url:
en_US: https://x.ai/api
supported_model_types:
- llm
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: endpoint_url
label:
en_US: API Base
type: text-input
required: false
default: https://api.x.ai/v1
placeholder:
zh_Hans: 在此输入您的 API Base
en_US: Enter your API Base
@@ -126,6 +126,6 @@ class OutputModeration(BaseModel):
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
logger.error("Moderation Output error: %s", e)
logger.exception("Moderation Output error: %s", e)

return None

@@ -708,7 +708,7 @@ class TraceQueueManager:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
except Exception as e:
logging.error(f"Error adding trace task: {e}")
logging.exception(f"Error adding trace task: {e}")
finally:
self.start_timer()

@@ -727,7 +727,7 @@ class TraceQueueManager:
if tasks:
self.send_to_celery(tasks)
except Exception as e:
logging.error(f"Error processing trace tasks: {e}")
logging.exception(f"Error processing trace tasks: {e}")

def start_timer(self):
global trace_manager_timer

@@ -242,7 +242,7 @@ class CouchbaseVector(BaseVector):
try:
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
except Exception as e:
logger.error(e)
logger.exception(e)

def delete_by_document_id(self, document_id: str):
query = f"""

0
api/core/rag/datasource/vdb/lindorm/__init__.py
Normal file
498
api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
Normal file
@@ -0,0 +1,498 @@
import copy
import json
import logging
from collections.abc import Iterable
from typing import Any, Optional

from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_fixed

from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("lindorm").setLevel(logging.WARN)


class LindormVectorStoreConfig(BaseModel):
hosts: str
username: Optional[str] = None
password: Optional[str] = None

@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["hosts"]:
raise ValueError("config URL is required")
if not values["username"]:
raise ValueError("config USERNAME is required")
if not values["password"]:
raise ValueError("config PASSWORD is required")
return values

def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": self.hosts,
}
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params


class LindormVectorStore(BaseVector):
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
super().__init__(collection_name.lower())
self._client_config = config
self._client = OpenSearch(**config.to_opensearch_params())
self.kwargs = kwargs

def get_type(self) -> str:
return VectorType.LINDORM

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.create_collection(len(embeddings[0]), **kwargs)
self.add_texts(texts, embeddings)

def refresh(self):
self._client.indices.refresh(index=self._collection_name)

def __filter_existed_ids(
self,
texts: list[str],
metadatas: list[dict],
ids: list[str],
bulk_size: int = 1024,
) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
try:
existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.exception(f"Error fetching batch {batch_ids}: {e}")
return set()

@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
try:
existing_docs = self._client.mget(
body={
"docs": [
{"_index": self._collection_name, "_id": id, "routing": routing}
for id, routing in zip(batch_ids, route_ids)
]
},
_source=False,
)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.exception(f"Error fetching batch {batch_ids}: {e}")
return set()

if ids is None:
return texts, metadatas, ids

if len(texts) != len(ids):
raise RuntimeError(f"texts {len(texts)} != ids {len(ids)}")

filtered_texts = []
filtered_metadatas = []
filtered_ids = []

def batch(iterable, n):
length = len(iterable)
for idx in range(0, length, n):
yield iterable[idx : min(idx + n, length)]

for ids_batch, texts_batch, metadatas_batch in zip(
batch(ids, bulk_size),
batch(texts, bulk_size),
batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
):
existing_ids_set = __fetch_existing_ids(ids_batch)
for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
if doc_id not in existing_ids_set:
filtered_texts.append(text)
filtered_ids.append(doc_id)
if metadatas is not None:
filtered_metadatas.append(metadata)

return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
actions = []
uuids = self._get_uuids(documents)
for i in range(len(documents)):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
"_id": uuids[i],
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
},
}
actions.append(action)
bulk(self._client, actions)
self.refresh()

def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
response = self._client.search(index=self._collection_name, body=query)
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
else:
return None

def delete_by_metadata_field(self, key: str, value: str):
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit["_id"] for hit in results["hits"]["hits"]]
if ids:
self.delete_by_ids(ids)

def delete_by_ids(self, ids: list[str]) -> None:
for id in ids:
if self._client.exists(index=self._collection_name, id=id):
self._client.delete(index=self._collection_name, id=id)
else:
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")

def delete(self) -> None:
try:
if self._client.indices.exists(index=self._collection_name):
self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
logger.info("Delete index success")
else:
logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
except Exception as e:
logger.exception(f"Error occurred while deleting the index: {e}")
raise e

def text_exists(self, id: str) -> bool:
try:
self._client.get(index=self._collection_name, id=id)
return True
except:
return False

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Make sure query_vector is a list
if not isinstance(query_vector, list):
raise ValueError("query_vector should be a list of floats")

# Check whether query_vector is a floating-point number list
if not all(isinstance(x, float) for x in query_vector):
raise ValueError("All elements in query_vector should be floats")

top_k = kwargs.get("top_k", 10)
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
try:
response = self._client.search(index=self._collection_name, body=query)
except Exception as e:
logger.exception(f"Error executing search: {e}")
raise

docs_and_scores = []
for hit in response["hits"]["hits"]:
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
),
hit["_score"],
)
)
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
if score > score_threshold:
doc.metadata["score"] = score
docs.append(doc)

return docs
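
`search_by_vector` delegates query construction to `default_vector_search_query`, which is not shown in the visible hunks. A purely hypothetical sketch of the kNN query body such a helper might build for an OpenSearch-compatible engine like Lindorm:

# Hypothetical stand-in for the unseen default_vector_search_query helper;
# the field name and body shape are assumptions, not the repo's code.
def default_vector_search_query(query_vector, k=10, vector_field="vector", **kwargs):
    return {
        "size": k,
        "query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
    }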

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
must = kwargs.get("must")
must_not = kwargs.get("must_not")
should = kwargs.get("should")
minimum_should_match = kwargs.get("minimum_should_match", 0)
top_k = kwargs.get("top_k", 10)
filters = kwargs.get("filter")
routing = kwargs.get("routing")
full_text_query = default_text_search_query(
query_text=query,
k=top_k,
text_field=Field.CONTENT_KEY.value,
must=must,
must_not=must_not,
should=should,
minimum_should_match=minimum_should_match,
filters=filters,
routing=routing,
)
response = self._client.search(index=self._collection_name, body=full_text_query)
docs = []
for hit in response["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
)
)

return docs

def create_collection(self, dimension: int, **kwargs):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
if self._client.indices.exists(index=self._collection_name):
logger.info(f"{self._collection_name.lower()} already exists.")
return
|
||||
if len(self.kwargs) == 0 and len(kwargs) != 0:
|
||||
self.kwargs = copy.deepcopy(kwargs)
|
||||
vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
|
||||
shards = kwargs.pop("shards", 2)
|
||||
|
||||
engine = kwargs.pop("engine", "lvector")
|
||||
method_name = kwargs.pop("method_name", "hnsw")
|
||||
data_type = kwargs.pop("data_type", "float")
|
||||
space_type = kwargs.pop("space_type", "cosinesimil")
|
||||
|
||||
hnsw_m = kwargs.pop("hnsw_m", 24)
|
||||
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
|
||||
ivfpq_m = kwargs.pop("ivfpq_m", dimension)
|
||||
nlist = kwargs.pop("nlist", 1000)
|
||||
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
|
||||
centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
|
||||
centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
|
||||
centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
|
||||
mapping = default_text_mapping(
|
||||
dimension,
|
||||
method_name,
|
||||
shards=shards,
|
||||
engine=engine,
|
||||
data_type=data_type,
|
||||
space_type=space_type,
|
||||
vector_field=vector_field,
|
||||
hnsw_m=hnsw_m,
|
||||
hnsw_ef_construction=hnsw_ef_construction,
|
||||
nlist=nlist,
|
||||
ivfpq_m=ivfpq_m,
|
||||
centroids_use_hnsw=centroids_use_hnsw,
|
||||
centroids_hnsw_m=centroids_hnsw_m,
|
||||
centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
|
||||
centroids_hnsw_ef_search=centroids_hnsw_ef_search,
|
||||
**kwargs,
|
||||
)
|
||||
self._client.indices.create(index=self._collection_name.lower(), body=mapping)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
# logger.info(f"create index success: {self._collection_name}")
|
||||
|
||||
|
||||
def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
    routing_field = kwargs.get("routing_field")
    excludes_from_source = kwargs.get("excludes_from_source")
    analyzer = kwargs.get("analyzer", "ik_max_word")
    text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
    engine = kwargs["engine"]
    shard = kwargs["shards"]
    space_type = kwargs["space_type"]
    data_type = kwargs["data_type"]
    vector_field = kwargs.get("vector_field", Field.VECTOR.value)

    if method_name == "ivfpq":
        ivfpq_m = kwargs["ivfpq_m"]
        nlist = kwargs["nlist"]
        centroids_use_hnsw = nlist > 10000
        centroids_hnsw_m = 24
        centroids_hnsw_ef_construct = 500
        centroids_hnsw_ef_search = 100
        parameters = {
            "m": ivfpq_m,
            "nlist": nlist,
            "centroids_use_hnsw": centroids_use_hnsw,
            "centroids_hnsw_m": centroids_hnsw_m,
            "centroids_hnsw_ef_construct": centroids_hnsw_ef_construct,
            "centroids_hnsw_ef_search": centroids_hnsw_ef_search,
        }
    elif method_name == "hnsw":
        neighbor = kwargs["hnsw_m"]
        ef_construction = kwargs["hnsw_ef_construction"]
        parameters = {"m": neighbor, "ef_construction": ef_construction}
    elif method_name == "flat":
        parameters = {}
    else:
        raise RuntimeError(f"unexpected method_name: {method_name}")

    mapping = {
        "settings": {"index": {"number_of_shards": shard, "knn": True}},
        "mappings": {
            "properties": {
                vector_field: {
                    "type": "knn_vector",
                    "dimension": dimension,
                    "data_type": data_type,
                    "method": {
                        "engine": engine,
                        "name": method_name,
                        "space_type": space_type,
                        "parameters": parameters,
                    },
                },
                text_field: {"type": "text", "analyzer": analyzer},
            }
        },
    }

    if excludes_from_source:
        mapping["mappings"]["_source"] = {"excludes": excludes_from_source}  # e.g. {"excludes": ["vector_field"]}

    if method_name == "ivfpq" and routing_field is not None:
        mapping["settings"]["index"]["knn_routing"] = True
        mapping["settings"]["index"]["knn.offline.construction"] = True

    if method_name == "flat" and routing_field is not None:
        mapping["settings"]["index"]["knn_routing"] = True

    return mapping

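For a concrete picture of the return value, this is the mapping the helper yields for the hnsw defaults above; the 768 dimension is an assumed example, and "vector"/"page_content" stand in for Field.VECTOR.value and Field.CONTENT_KEY.value:

example_hnsw_mapping = {
    "settings": {"index": {"number_of_shards": 2, "knn": True}},
    "mappings": {
        "properties": {
            "vector": {  # Field.VECTOR.value
                "type": "knn_vector",
                "dimension": 768,  # assumed embedding size
                "data_type": "float",
                "method": {
                    "engine": "lvector",
                    "name": "hnsw",
                    "space_type": "cosinesimil",
                    "parameters": {"m": 24, "ef_construction": 500},
                },
            },
            "page_content": {"type": "text", "analyzer": "ik_max_word"},  # Field.CONTENT_KEY.value
        }
    },
}
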
def default_text_search_query(
    query_text: str,
    k: int = 4,
    text_field: str = Field.CONTENT_KEY.value,
    must: Optional[list[dict]] = None,
    must_not: Optional[list[dict]] = None,
    should: Optional[list[dict]] = None,
    minimum_should_match: int = 0,
    filters: Optional[list[dict]] = None,
    routing: Optional[str] = None,
    **kwargs,
) -> dict:
    if routing is not None:
        routing_field = kwargs.get("routing_field", "routing_field")
        query_clause = {
            "bool": {
                "must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
            }
        }
    else:
        query_clause = {"match": {text_field: query_text}}
    # build the simplest search_query when only query_text is specified
    if not must and not must_not and not should and not filters:
        search_query = {"size": k, "query": query_clause}
        return search_query

    # build a complex search_query when any of must/must_not/should/filters is specified
    if must:
        if not isinstance(must, list):
            raise RuntimeError(f"unexpected [must] clause with {type(must)}")
        if query_clause not in must:
            must.append(query_clause)
    else:
        must = [query_clause]

    boolean_query = {"must": must}

    if must_not:
        if not isinstance(must_not, list):
            raise RuntimeError(f"unexpected [must_not] clause with {type(must_not)}")
        boolean_query["must_not"] = must_not

    if should:
        if not isinstance(should, list):
            raise RuntimeError(f"unexpected [should] clause with {type(should)}")
        boolean_query["should"] = should
        if minimum_should_match != 0:
            boolean_query["minimum_should_match"] = minimum_should_match

    if filters:
        if not isinstance(filters, list):
            raise RuntimeError(f"unexpected [filter] clause with {type(filters)}")
        boolean_query["filter"] = filters

    search_query = {"size": k, "query": {"bool": boolean_query}}
    return search_query

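A sketch of the query this builder returns when only query_text and a routing value are supplied (default field names assumed):

# Illustrative output of default_text_search_query("cloud database", k=4, routing="tenant_a")
example_text_query = {
    "size": 4,
    "query": {
        "bool": {
            "must": [
                {"match": {"page_content": "cloud database"}},
                {"term": {"metadata.routing_field.keyword": "tenant_a"}},
            ]
        }
    },
}
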
def default_vector_search_query(
    query_vector: list[float],
    k: int = 4,
    min_score: str = "0.0",
    ef_search: Optional[str] = None,  # only for hnsw
    nprobe: Optional[str] = None,  # "2000"
    reorder_factor: Optional[str] = None,  # "20"
    client_refactor: Optional[str] = None,  # "true"
    vector_field: str = Field.VECTOR.value,
    filters: Optional[list[dict]] = None,
    filter_type: Optional[str] = None,
    **kwargs,
) -> dict:
    if filters is not None:
        filter_type = "post_filter" if filter_type is None else filter_type
        if not isinstance(filters, list):
            raise RuntimeError(f"unexpected filter with {type(filters)}")
    final_ext = {"lvector": {}}
    if min_score != "0.0":
        final_ext["lvector"]["min_score"] = min_score
    if ef_search:
        final_ext["lvector"]["ef_search"] = ef_search
    if nprobe:
        final_ext["lvector"]["nprobe"] = nprobe
    if reorder_factor:
        final_ext["lvector"]["reorder_factor"] = reorder_factor
    if client_refactor:
        final_ext["lvector"]["client_refactor"] = client_refactor

    search_query = {
        "size": k,
        "_source": True,  # force return '_source'
        "query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
    }

    if filters is not None:
        # when using a filter, transform it from list[dict] to a single dict, the valid format
        filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
        search_query["query"]["knn"][vector_field]["filter"] = filters  # the filter must be a dict
        if filter_type:
            final_ext["lvector"]["filter_type"] = filter_type

    if final_ext != {"lvector": {}}:
        search_query["ext"] = final_ext
    return search_query

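Likewise, a sketch of the knn query produced when a single filter is passed (the metadata field name and the truncated vector are assumed examples):

# Illustrative output of default_vector_search_query([0.1, 0.2], k=4,
# filters=[{"term": {"metadata.document_id": "doc-1"}}])
example_vector_query = {
    "size": 4,
    "_source": True,
    "query": {
        "knn": {
            "vector": {
                "vector": [0.1, 0.2],  # truncated example embedding
                "k": 4,
                "filter": {"term": {"metadata.document_id": "doc-1"}},
            }
        }
    },
    "ext": {"lvector": {"filter_type": "post_filter"}},
}
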
class LindormVectorStoreFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
        lindorm_config = LindormVectorStoreConfig(
            hosts=dify_config.LINDORM_URL,
            username=dify_config.LINDORM_USERNAME,
            password=dify_config.LINDORM_PASSWORD,
        )
        return LindormVectorStore(collection_name, lindorm_config)

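A minimal wiring sketch, assuming a Dataset and an Embeddings instance are provided by the surrounding service layer:

factory = LindormVectorStoreFactory()
vector_store = factory.init_vector(dataset, attributes=[], embeddings=embeddings)
vector_store.create_collection(dimension=768)  # example embedding size
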
@@ -86,7 +86,7 @@ class MilvusVector(BaseVector):
                 ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
                 pks.extend(ids)
             except MilvusException as e:
-                logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count)
+                logger.exception("Failed to insert batch starting at entity: %s/%s", i, total_count)
                 raise e
         return pks

@@ -142,7 +142,7 @@ class MyScaleVector(BaseVector):
                 for r in self._client.query(sql).named_results()
             ]
         except Exception as e:
-            logging.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
+            logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
             return []

     def delete(self) -> None:

@@ -129,7 +129,7 @@ class OpenSearchVector(BaseVector):
             if status == 404:
                 logger.warning(f"Document not found for deletion: {doc_id}")
             else:
-                logger.error(f"Error deleting document: {error}")
+                logger.exception(f"Error deleting document: {error}")

     def delete(self) -> None:
         self._client.indices.delete(index=self._collection_name.lower())

@@ -158,7 +158,7 @@ class OpenSearchVector(BaseVector):
         try:
             response = self._client.search(index=self._collection_name.lower(), body=query)
         except Exception as e:
-            logger.error(f"Error executing search: {e}")
+            logger.exception(f"Error executing search: {e}")
             raise

         docs = []

@@ -134,6 +134,10 @@ class Vector:
                 from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory

                 return TidbOnQdrantVectorFactory
+            case VectorType.LINDORM:
+                from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
+
+                return LindormVectorStoreFactory
             case VectorType.OCEANBASE:
                 from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory

@@ -16,6 +16,7 @@ class VectorType(str, Enum):
     TENCENT = "tencent"
     ORACLE = "oracle"
     ELASTICSEARCH = "elasticsearch"
+    LINDORM = "lindorm"
     COUCHBASE = "couchbase"
     BAIDU = "baidu"
     VIKINGDB = "vikingdb"

@@ -89,7 +89,7 @@ class CacheEmbedding(Embeddings):
                 db.session.rollback()
         except Exception as ex:
             db.session.rollback()
-            logger.error("Failed to embed documents: %s", ex)
+            logger.exception("Failed to embed documents: %s", ex)
             raise ex

         return text_embeddings

@@ -14,6 +14,7 @@ import requests
 from docx import Document as DocxDocument

 from configs import dify_config
+from core.helper import ssrf_proxy
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from extensions.ext_database import db

@@ -27,7 +28,6 @@ logger = logging.getLogger(__name__)
 class WordExtractor(BaseExtractor):
     """Load docx files.

-
     Args:
         file_path: Path to the file to load.
     """

@@ -50,9 +50,9 @@ class WordExtractor(BaseExtractor):

             self.web_path = self.file_path
             # TODO: use a better way to handle the file
-            self.temp_file = tempfile.NamedTemporaryFile()  # noqa: SIM115
-            self.temp_file.write(r.content)
-            self.file_path = self.temp_file.name
+            with tempfile.NamedTemporaryFile(delete=False) as self.temp_file:
+                self.temp_file.write(r.content)
+                self.file_path = self.temp_file.name
         elif not os.path.isfile(self.file_path):
             raise ValueError(f"File path {self.file_path} is not a valid file or url")

@@ -86,7 +86,7 @@ class WordExtractor(BaseExtractor):
             image_count += 1
             if rel.is_external:
                 url = rel.reltype
-                response = requests.get(url, stream=True)
+                response = ssrf_proxy.get(url, stream=True)
                 if response.status_code == 200:
                     image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
                     file_uuid = str(uuid.uuid4())

@@ -229,7 +229,7 @@ class WordExtractor(BaseExtractor):
                     for i in url_pattern.findall(x.text):
                         hyperlinks_url = str(i)
                 except Exception as e:
-                    logger.error(e)
+                    logger.exception(e)

     def parse_paragraph(paragraph):
         paragraph_content = []

@@ -48,6 +48,13 @@
   - feishu_task
   - feishu_calendar
   - feishu_spreadsheet
+  - lark_base
+  - lark_document
+  - lark_message_and_group
+  - lark_wiki
+  - lark_task
+  - lark_calendar
+  - lark_spreadsheet
   - slack
   - twilio
   - wecom

@@ -4,7 +4,7 @@ from hmac import new as hmac_new
 from json import loads as json_loads
 from threading import Lock
 from time import sleep, time
-from typing import Any, Optional
+from typing import Any

 from httpx import get, post
 from requests import get as requests_get

@@ -15,27 +15,27 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter,
 from core.tools.tool.builtin_tool import BuiltinTool


-class AIPPTGenerateTool(BuiltinTool):
+class AIPPTGenerateToolAdapter:
     """
     A tool for generating a ppt
     """

     _api_base_url = URL("https://co.aippt.cn/api")
     _api_token_cache = {}
-    _api_token_cache_lock: Optional[Lock] = None
     _style_cache = {}
-    _style_cache_lock: Optional[Lock] = None
+    _api_token_cache_lock = Lock()
+    _style_cache_lock = Lock()

     _task = {}
     _task_type_map = {
         "auto": 1,
         "markdown": 7,
     }
+    _tool: BuiltinTool

-    def __init__(self, **kwargs: Any):
-        super().__init__(**kwargs)
-        self._api_token_cache_lock = Lock()
-        self._style_cache_lock = Lock()
+    def __init__(self, tool: BuiltinTool = None):
+        self._tool = tool

     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
         """

@@ -51,11 +51,11 @@ class AIPPTGenerateTool(BuiltinTool):
         """
         title = tool_parameters.get("title", "")
         if not title:
-            return self.create_text_message("Please provide a title for the ppt")
+            return self._tool.create_text_message("Please provide a title for the ppt")

         model = tool_parameters.get("model", "aippt")
         if not model:
-            return self.create_text_message("Please provide a model for the ppt")
+            return self._tool.create_text_message("Please provide a model for the ppt")

         outline = tool_parameters.get("outline", "")

@@ -68,8 +68,8 @@ class AIPPTGenerateTool(BuiltinTool):
         )

         # get suit
-        color = tool_parameters.get("color")
-        style = tool_parameters.get("style")
+        color: str = tool_parameters.get("color")
+        style: str = tool_parameters.get("style")

         if color == "__default__":
             color_id = ""
@@ -93,9 +93,9 @@ class AIPPTGenerateTool(BuiltinTool):
         # generate ppt
         _, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id)

-        return self.create_text_message(
+        return self._tool.create_text_message(
             """the ppt has been created successfully,"""
-            f"""the ppt url is {ppt_url}"""
+            f"""the ppt url is {ppt_url} ."""
             """please give the ppt url to user and direct user to download it."""
         )
@@ -111,8 +111,8 @@ class AIPPTGenerateTool(BuiltinTool):
         """
         headers = {
             "x-channel": "",
-            "x-api-key": self.runtime.credentials["aippt_access_key"],
-            "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
+            "x-api-key": self._tool.runtime.credentials["aippt_access_key"],
+            "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
         }
         response = post(
             str(self._api_base_url / "ai" / "chat" / "v2" / "task"),
@@ -139,8 +139,8 @@ class AIPPTGenerateTool(BuiltinTool):

         headers = {
             "x-channel": "",
-            "x-api-key": self.runtime.credentials["aippt_access_key"],
-            "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
+            "x-api-key": self._tool.runtime.credentials["aippt_access_key"],
+            "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
         }

         response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
@@ -183,8 +183,8 @@ class AIPPTGenerateTool(BuiltinTool):

         headers = {
             "x-channel": "",
-            "x-api-key": self.runtime.credentials["aippt_access_key"],
-            "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
+            "x-api-key": self._tool.runtime.credentials["aippt_access_key"],
+            "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
         }

         response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
@@ -236,14 +236,15 @@ class AIPPTGenerateTool(BuiltinTool):
         """
         headers = {
             "x-channel": "",
-            "x-api-key": self.runtime.credentials["aippt_access_key"],
-            "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
+            "x-api-key": self._tool.runtime.credentials["aippt_access_key"],
+            "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
         }

         response = post(
             str(self._api_base_url / "design" / "v2" / "save"),
             headers=headers,
             data={"task_id": task_id, "template_id": suit_id},
             timeout=(10, 60),
         )

         if response.status_code != 200:
@@ -350,11 +351,13 @@ class AIPPTGenerateTool(BuiltinTool):

         return token

-    @classmethod
-    def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
+    @staticmethod
+    def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str:
         return b64encode(
             hmac_new(
-                key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1
+                key=secret_key.encode("utf-8"),
+                msg=f"GET@/api/grant/token/@{timestamp}".encode(),
+                digestmod=sha1,
             ).digest()
         ).decode("utf-8")
@@ -419,10 +422,12 @@ class AIPPTGenerateTool(BuiltinTool):
         :param credentials: the credentials
         :return: Tuple[list[dict[id, color]], list[dict[id, style]]
         """
-        if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"):
+        if not self._tool.runtime.credentials.get("aippt_access_key") or not self._tool.runtime.credentials.get(
+            "aippt_secret_key"
+        ):
             raise Exception("Please provide aippt credentials")

-        return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
+        return self._get_styles(credentials=self._tool.runtime.credentials, user_id=user_id)

     def _get_suit(self, style_id: int, colour_id: int) -> int:
         """
@@ -430,8 +435,8 @@ class AIPPTGenerateTool(BuiltinTool):
         """
         headers = {
             "x-channel": "",
-            "x-api-key": self.runtime.credentials["aippt_access_key"],
-            "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"),
+            "x-api-key": self._tool.runtime.credentials["aippt_access_key"],
+            "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id="__dify_system__"),
         }
         response = get(
             str(self._api_base_url / "template_component" / "suit" / "search"),
@@ -496,3 +501,18 @@ class AIPPTGenerateTool(BuiltinTool):
             ],
         ),
     ]
+
+
+class AIPPTGenerateTool(BuiltinTool):
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+        return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters)
+
+    def get_runtime_parameters(self) -> list[ToolParameter]:
+        return AIPPTGenerateToolAdapter(self).get_runtime_parameters()
+
+    @classmethod
+    def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
+        return AIPPTGenerateToolAdapter()._get_api_token(credentials, user_id)

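For context on the signing change above, a self-contained sketch of the signature _calculate_sign produces (the keys are placeholders; real values come from the tool's runtime credentials):

from base64 import b64encode
from hashlib import sha1
from hmac import new as hmac_new
from time import time

access_key, secret_key = "ak_example", "sk_example"  # placeholders
timestamp = int(time())
sign = b64encode(
    hmac_new(
        key=secret_key.encode("utf-8"),
        msg=f"GET@/api/grant/token/@{timestamp}".encode(),
        digestmod=sha1,
    ).digest()
).decode("utf-8")
# `sign`, together with access_key and timestamp, authorizes the token request.
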
@@ -8,6 +8,20 @@ from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClie
 from core.tools.tool.builtin_tool import BuiltinTool


+def sanitize_json_string(s):
+    escape_dict = {
+        "\n": "\\n",
+        "\r": "\\r",
+        "\t": "\\t",
+        "\b": "\\b",
+        "\f": "\\f",
+    }
+    for char, escaped in escape_dict.items():
+        s = s.replace(char, escaped)
+
+    return s
+
+
 class ComfyUIWorkflowTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
         comfyui = ComfyUiClient(self.runtime.credentials["base_url"])
@@ -26,13 +40,17 @@ class ComfyUIWorkflowTool(BuiltinTool):
         set_prompt_with_ksampler = True
         if "{{positive_prompt}}" in workflow:
             set_prompt_with_ksampler = False
-            workflow = workflow.replace("{{positive_prompt}}", positive_prompt)
-            workflow = workflow.replace("{{negative_prompt}}", negative_prompt)
+            workflow = workflow.replace("{{positive_prompt}}", positive_prompt.replace('"', "'"))
+            workflow = workflow.replace("{{negative_prompt}}", negative_prompt.replace('"', "'"))

         try:
             prompt = json.loads(workflow)
-        except:
-            return self.create_text_message("the Workflow JSON is not correct")
+        except json.JSONDecodeError:
+            cleaned_string = sanitize_json_string(workflow)
+            try:
+                prompt = json.loads(cleaned_string)
+            except:
+                return self.create_text_message("the Workflow JSON is not correct")

         if set_prompt_with_ksampler:
             try:

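A quick sketch of what the new sanitize-and-retry path does with a workflow string containing a raw newline (illustrative input):

import json

raw = '{"prompt": "a cat\nin the rain"}'  # the literal newline makes json.loads fail
cleaned = sanitize_json_string(raw)       # the newline becomes the two characters \n
prompt = json.loads(cleaned)              # parses; prompt["prompt"] keeps the newline
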
@@ -34,7 +34,7 @@ parameters:
       Page size, default value: 20, maximum value: 100.
       zh_Hans: 分页大小,默认值:20,最大值:100。
     llm_description: 分页大小,默认值:20,最大值:100。
-    form: llm
+    form: form

   - name: page_token
     type: string
@@ -147,7 +147,7 @@ parameters:
       Page size, default value: 20, maximum value: 500.
       zh_Hans: 分页大小,默认值:20,最大值:500。
     llm_description: 分页大小,默认值:20,最大值:500。
-    form: llm
+    form: form

   - name: page_token
     type: string
@@ -47,7 +47,7 @@ parameters:
       en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 50, and the value range is [50,1000].
      zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 50,取值范围为 [50,1000]。
     llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 50,取值范围为 [50,1000]。
-    form: llm
+    form: form

   - name: page_token
     type: string
@@ -85,7 +85,7 @@ parameters:
       en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 20, and the value range is [10,100].
      zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [10,100]。
     llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [10,100]。
-    form: llm
+    form: form

   - name: page_token
     type: string
@@ -59,7 +59,7 @@ parameters:
       en_US: Paging size, the default and maximum value is 500.
      zh_Hans: 分页大小, 默认值和最大值为 500。
     llm_description: 分页大小, 表示一次请求最多返回多少条数据,默认值和最大值为 500。
-    form: llm
+    form: form

   - name: page_token
     type: string
@@ -81,7 +81,7 @@ parameters:
       en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 20, and the value range is [1,50].
      zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [1,50]。
     llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [1,50]。
-    form: llm
+    form: form

   - name: page_token
     type: string
@@ -57,7 +57,7 @@ parameters:
       en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 20, and the value range is [1,50].
      zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [1,50]。
     llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [1,50]。
-    form: llm
+    form: form

   - name: page_token
     type: string
@@ -56,7 +56,7 @@ parameters:
       en_US: Number of columns to add, range (0-5000].
      zh_Hans: 要增加的列数,范围(0-5000]。
     llm_description: 要增加的列数,范围(0-5000]。
-    form: llm
+    form: form

   - name: values
     type: string
@@ -56,7 +56,7 @@ parameters:
       en_US: Number of rows to add, range (0-5000].
      zh_Hans: 要增加行数,范围(0-5000]。
     llm_description: 要增加行数,范围(0-5000]。
-    form: llm
+    form: form

   - name: values
     type: string
@@ -82,7 +82,7 @@ parameters:
       en_US: Starting column number, starting from 1.
      zh_Hans: 起始列号,从 1 开始。
     llm_description: 起始列号,从 1 开始。
-    form: llm
+    form: form

   - name: num_cols
     type: number
@@ -94,4 +94,4 @@ parameters:
       en_US: Number of columns to read.
      zh_Hans: 读取列数
     llm_description: 读取列数
-    form: llm
+    form: form
@@ -82,7 +82,7 @@ parameters:
       en_US: Starting row number, starting from 1.
      zh_Hans: 起始行号,从 1 开始。
     llm_description: 起始行号,从 1 开始。
-    form: llm
+    form: form

   - name: num_rows
     type: number
@@ -94,4 +94,4 @@ parameters:
       en_US: Number of rows to read.
      zh_Hans: 读取行数
     llm_description: 读取行数
-    form: llm
+    form: form
@@ -82,7 +82,7 @@ parameters:
       en_US: Starting row number, starting from 1.
      zh_Hans: 起始行号,从 1 开始。
     llm_description: 起始行号,从 1 开始。
-    form: llm
+    form: form

   - name: num_rows
     type: number
@@ -94,7 +94,7 @@ parameters:
       en_US: Number of rows to read.
      zh_Hans: 读取行数
     llm_description: 读取行数
-    form: llm
+    form: form

   - name: range
     type: string
@@ -36,7 +36,7 @@ parameters:
       en_US: The size of each page, with a maximum value of 50.
      zh_Hans: 分页大小,最大值 50。
     llm_description: 分页大小,最大值 50。
-    form: llm
+    form: form

   - name: page_token
     type: string

@@ -0,0 +1,3 @@
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M25.132 24.3947C25.497 25.7527 25.8984 27.1413 26.3334 28.5834C26.7302 29.8992 25.5459 30.4167 25.0752 29.1758C24.571 27.8466 24.0885 26.523 23.6347 25.1729C21.065 26.4654 18.5025 27.5424 15.5961 28.7541C16.7581 33.0256 17.8309 36.5984 19.4952 39.9935C19.4953 39.9936 19.4953 39.9937 19.4954 39.9938C19.6631 39.9979 19.8313 40 20 40C31.0457 40 40 31.0457 40 20C40 16.0335 38.8453 12.3366 36.8537 9.22729C31.6585 9.69534 27.0513 10.4562 22.8185 11.406C22.8882 12.252 22.9677 13.0739 23.0555 13.855C23.3824 16.7604 23.9112 19.5281 24.6137 22.3836C27.0581 21.2848 29.084 20.3225 30.6816 19.522C32.2154 18.7535 33.6943 18.7062 31.2018 20.6594C29.0388 22.1602 27.0644 23.3566 25.132 24.3947ZM36.1559 8.20846C33.0001 3.89184 28.1561 0.887462 22.5955 0.166882C22.4257 2.86234 22.4785 6.26344 22.681 9.50447C26.7473 8.88859 31.1721 8.46032 36.1559 8.20846ZM19.9369 9.73661e-05C19.7594 2.92694 19.8384 6.65663 20.19 9.91293C17.3748 10.4109 14.7225 11.0064 12.1592 11.7038C12.0486 10.4257 11.9927 9.25764 11.9927 8.24178C11.9927 7.5054 11.3957 6.90844 10.6593 6.90844C9.92296 6.90844 9.32601 7.5054 9.32601 8.24178C9.32601 9.47868 9.42873 10.898 9.61402 12.438C8.33567 12.8278 7.07397 13.2443 5.81918 13.688C5.12493 13.9336 4.76118 14.6954 5.0067 15.3896C5.25223 16.0839 6.01406 16.4476 6.7083 16.2021C7.7931 15.8185 8.88482 15.4388 9.98927 15.0659C10.5222 18.3344 11.3344 21.9428 12.2703 25.4156C12.4336 26.0218 12.6062 26.6262 12.7863 27.2263C9.34168 28.4135 5.82612 29.3782 2.61128 29.8879C0.949407 26.9716 0 23.5967 0 20C0 8.97534 8.92023 0.0341108 19.9369 9.73661e-05ZM4.19152 32.2527C7.45069 36.4516 12.3458 39.3173 17.9204 39.8932C16.5916 37.455 14.9338 33.717 13.5405 29.5901C10.4404 30.7762 7.25883 31.6027 4.19152 32.2527ZM22.9735 23.1135C22.1479 20.41 21.4462 17.5441 20.9225 14.277C20.746 13.5841 20.5918 12.8035 20.4593 11.9636C17.6508 12.6606 14.9992 13.4372 12.4356 14.2598C12.8479 17.4766 13.5448 21.1334 14.5118 24.7218C14.662 25.2792 14.8081 25.8248 14.9514 26.3594L14.9516 26.3603L14.9524 26.3634L14.9526 26.3639L14.973 26.4401C16.1833 25.9872 17.3746 25.5123 18.53 25.0259C20.1235 24.3552 21.6051 23.7165 22.9735 23.1135Z" fill="#141519"/>
</svg>

api/core/tools/provider/builtin/gitee_ai/gitee_ai.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import requests

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class GiteeAIProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        url = "https://ai.gitee.com/api/base/account/me"
        headers = {
            "accept": "application/json",
            "authorization": f"Bearer {credentials.get('api_key')}",
        }

        response = requests.get(url, headers=headers)
        if response.status_code != 200:
            raise ToolProviderCredentialValidationError("GiteeAI API key is invalid")

api/core/tools/provider/builtin/gitee_ai/gitee_ai.yaml (new file, 22 lines)
@@ -0,0 +1,22 @@
identity:
  author: Gitee AI
  name: gitee_ai
  label:
    en_US: Gitee AI
    zh_Hans: Gitee AI
  description:
    en_US: Quickly try out large models and explore the open-source AI world
    zh_Hans: 快速体验大模型,领先探索 AI 开源世界
  icon: icon.svg
  tags:
    - image
credentials_for_provider:
  api_key:
    type: secret-input
    required: true
    label:
      en_US: API Key
    placeholder:
      zh_Hans: 在此输入您的 API Key
      en_US: Enter your API Key
    url: https://ai.gitee.com/dashboard/settings/tokens

@@ -0,0 +1,33 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class GiteeAITool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {self.runtime.credentials['api_key']}",
        }

        payload = {
            "inputs": tool_parameters.get("inputs"),
            "width": tool_parameters.get("width", "720"),
            "height": tool_parameters.get("height", "720"),
        }
        model = tool_parameters.get("model", "Kolors")
        url = f"https://ai.gitee.com/api/serverless/{model}/text-to-image"

        response = requests.post(url, json=payload, headers=headers)
        if response.status_code != 200:
            return self.create_text_message(f"Got Error Response:{response.text}")

        # The returned image is base64 and needs to be marked as an image
        result = [self.create_blob_message(blob=response.content, meta={"mime_type": "image/jpeg"})]

        return result

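A standalone sketch of the request this tool issues (the token and prompt are placeholders; the endpoint and payload mirror the _invoke body above):

import requests

token = "YOUR_GITEE_AI_TOKEN"  # placeholder
resp = requests.post(
    "https://ai.gitee.com/api/serverless/Kolors/text-to-image",
    headers={"content-type": "application/json", "authorization": f"Bearer {token}"},
    json={"inputs": "a watercolor fox", "width": "720", "height": "720"},
)
resp.raise_for_status()
image_bytes = resp.content  # image payload, to be saved or wrapped as a blob message
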
@@ -0,0 +1,72 @@
identity:
  name: text to image
  author: gitee_ai
  label:
    en_US: text to image
  icon: icon.svg
description:
  human:
    en_US: generate images using a variety of popular models
  llm: This tool is used to generate image from text.
parameters:
  - name: model
    type: select
    required: true
    options:
      - value: flux-1-schnell
        label:
          en_US: flux-1-schnell
      - value: Kolors
        label:
          en_US: Kolors
      - value: stable-diffusion-3-medium
        label:
          en_US: stable-diffusion-3-medium
      - value: stable-diffusion-xl-base-1.0
        label:
          en_US: stable-diffusion-xl-base-1.0
      - value: stable-diffusion-v1-4
        label:
          en_US: stable-diffusion-v1-4
    default: Kolors
    label:
      en_US: Choose Image Model
      zh_Hans: 选择生成图片的模型
    form: form
  - name: inputs
    type: string
    required: true
    label:
      en_US: Input Text
      zh_Hans: 输入文本
    human_description:
      en_US: The text input used to generate the image.
      zh_Hans: 用于生成图片的输入文本。
    llm_description: This text input will be used to generate image.
    form: llm
  - name: width
    type: number
    required: true
    default: 720
    min: 1
    max: 1024
    label:
      en_US: Image Width
      zh_Hans: 图片宽度
    human_description:
      en_US: The width of the generated image.
      zh_Hans: 生成图片的宽度。
    form: form
  - name: height
    type: number
    required: true
    default: 720
    min: 1
    max: 1024
    label:
      en_US: Image Height
      zh_Hans: 图片高度
    human_description:
      en_US: The height of the generated image.
      zh_Hans: 生成图片的高度。
    form: form

@@ -13,15 +13,15 @@ class GitlabCommitsTool(BuiltinTool):
     def _invoke(
         self, user_id: str, tool_parameters: dict[str, Any]
     ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
-        project = tool_parameters.get("project", "")
+        branch = tool_parameters.get("branch", "")
         repository = tool_parameters.get("repository", "")
         employee = tool_parameters.get("employee", "")
         start_time = tool_parameters.get("start_time", "")
         end_time = tool_parameters.get("end_time", "")
         change_type = tool_parameters.get("change_type", "all")

-        if not project and not repository:
-            return self.create_text_message("Either project or repository is required")
+        if not repository:
+            return self.create_text_message("Repository is required")

         if not start_time:
             start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
@@ -37,14 +37,9 @@ class GitlabCommitsTool(BuiltinTool):
             site_url = "https://gitlab.com"

         # Get commit content
-        if repository:
-            result = self.fetch_commits(
-                site_url, access_token, repository, employee, start_time, end_time, change_type, is_repository=True
-            )
-        else:
-            result = self.fetch_commits(
-                site_url, access_token, project, employee, start_time, end_time, change_type, is_repository=False
-            )
+        result = self.fetch_commits(
+            site_url, access_token, repository, branch, employee, start_time, end_time, change_type, is_repository=True
+        )

         return [self.create_json_message(item) for item in result]

@@ -52,7 +47,8 @@ class GitlabCommitsTool(BuiltinTool):
         self,
         site_url: str,
         access_token: str,
-        identifier: str,
+        repository: str,
+        branch: str,
         employee: str,
         start_time: str,
         end_time: str,
@@ -64,27 +60,14 @@ class GitlabCommitsTool(BuiltinTool):
         results = []

         try:
-            if is_repository:
-                # URL encode the repository path
-                encoded_identifier = urllib.parse.quote(identifier, safe="")
-                commits_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits"
-            else:
-                # Get all projects
-                url = f"{domain}/api/v4/projects"
-                response = requests.get(url, headers=headers)
-                response.raise_for_status()
-                projects = response.json()
-
-                filtered_projects = [p for p in projects if identifier == "*" or p["name"] == identifier]
-
-                for project in filtered_projects:
-                    project_id = project["id"]
-                    project_name = project["name"]
-                    print(f"Project: {project_name}")
-
-                    commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
+            # URL encode the repository path
+            encoded_repository = urllib.parse.quote(repository, safe="")
+            commits_url = f"{domain}/api/v4/projects/{encoded_repository}/repository/commits"

             # Fetch commits for the repository
             params = {"since": start_time, "until": end_time}
+            if branch:
+                params["ref_name"] = branch
             if employee:
                 params["author"] = employee

@@ -96,10 +79,7 @@ class GitlabCommitsTool(BuiltinTool):
                 commit_sha = commit["id"]
                 author_name = commit["author_name"]

-                if is_repository:
-                    diff_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits/{commit_sha}/diff"
-                else:
-                    diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
+                diff_url = f"{domain}/api/v4/projects/{encoded_repository}/repository/commits/{commit_sha}/diff"

                 diff_response = requests.get(diff_url, headers=headers)
                 diff_response.raise_for_status()
@@ -120,7 +100,14 @@ class GitlabCommitsTool(BuiltinTool):
                             if line.startswith("+") and not line.startswith("+++")
                         ]
                     )
-                    results.append({"commit_sha": commit_sha, "author_name": author_name, "diff": final_code})
+                    results.append(
+                        {
+                            "diff_url": diff_url,
+                            "commit_sha": commit_sha,
+                            "author_name": author_name,
+                            "diff": final_code,
+                        }
+                    )
                 else:
                     if total_changes > 1:
                         final_code = "".join(
@@ -134,7 +121,12 @@ class GitlabCommitsTool(BuiltinTool):
                         )
                         final_code_escaped = json.dumps(final_code)[1:-1]  # Escape the final code
                         results.append(
-                            {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped}
+                            {
+                                "diff_url": diff_url,
+                                "commit_sha": commit_sha,
+                                "author_name": author_name,
+                                "diff": final_code_escaped,
+                            }
                         )
         except requests.RequestException as e:
             print(f"Error fetching data from GitLab: {e}")

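A minimal sketch of the calls fetch_commits now makes against the GitLab API (host, token, repository, and dates are placeholders; PRIVATE-TOKEN is the standard GitLab auth header):

import urllib.parse
import requests

domain, access_token = "https://gitlab.com", "YOUR_TOKEN"  # placeholders
encoded_repository = urllib.parse.quote("my-group/my-project", safe="")  # namespace/project_name
commits_url = f"{domain}/api/v4/projects/{encoded_repository}/repository/commits"
params = {"since": "2024-11-01T00:00:00", "until": "2024-11-02T00:00:00", "ref_name": "main"}
resp = requests.get(commits_url, headers={"PRIVATE-TOKEN": access_token}, params=params)
resp.raise_for_status()
for commit in resp.json():
    print(commit["id"], commit["author_name"])
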
@@ -23,7 +23,7 @@ parameters:
     form: llm
   - name: repository
     type: string
-    required: false
+    required: true
     label:
       en_US: repository
       zh_Hans: 仓库路径
@@ -32,16 +32,16 @@ parameters:
       zh_Hans: 仓库路径,以namespace/project_name的形式。
     llm_description: Repository path for GitLab, like namespace/project_name.
     form: llm
-  - name: project
+  - name: branch
     type: string
     required: false
     label:
-      en_US: project
-      zh_Hans: 项目名
+      en_US: branch
+      zh_Hans: 分支名
     human_description:
-      en_US: project
-      zh_Hans: 项目名
-    llm_description: project for GitLab
+      en_US: branch
+      zh_Hans: 分支名
+    llm_description: branch for GitLab
     form: llm
   - name: start_time
     type: string