mirror of
https://gitee.com/dify_ai/dify.git
synced 2025-12-07 11:55:44 +08:00
Compare commits
3 Commits
refactor/t
...
fix/knowle
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
23086a2d02 | ||
|
|
7083804d48 | ||
|
|
7aa0148df2 |
@@ -0,0 +1,45 @@
|
||||
import type { FC } from 'react'
|
||||
import { createContext, useEffect, useRef } from 'react'
|
||||
import { createDatasetsDetailStore } from './store'
|
||||
import type { CommonNodeType, Node } from '../types'
|
||||
import { BlockEnum } from '../types'
|
||||
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
|
||||
|
||||
type DatasetsDetailStoreApi = ReturnType<typeof createDatasetsDetailStore>
|
||||
|
||||
type DatasetsDetailContextType = DatasetsDetailStoreApi | undefined
|
||||
|
||||
export const DatasetsDetailContext = createContext<DatasetsDetailContextType>(undefined)
|
||||
|
||||
type DatasetsDetailProviderProps = {
|
||||
nodes: Node[]
|
||||
children: React.ReactNode
|
||||
}
|
||||
|
||||
const DatasetsDetailProvider: FC<DatasetsDetailProviderProps> = ({
|
||||
nodes,
|
||||
children,
|
||||
}) => {
|
||||
const storeRef = useRef<DatasetsDetailStoreApi>()
|
||||
|
||||
if (!storeRef.current)
|
||||
storeRef.current = createDatasetsDetailStore()
|
||||
|
||||
useEffect(() => {
|
||||
if (!storeRef.current) return
|
||||
const knowledgeRetrievalNodes = nodes.filter(node => node.data.type === BlockEnum.KnowledgeRetrieval)
|
||||
const allDatasetIds = knowledgeRetrievalNodes.reduce<string[]>((acc, node) => {
|
||||
return Array.from(new Set([...acc, ...(node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids]))
|
||||
}, [])
|
||||
storeRef.current.getState().updateDatasetsDetail(allDatasetIds)
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<DatasetsDetailContext.Provider value={storeRef.current!}>
|
||||
{children}
|
||||
</DatasetsDetailContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export default DatasetsDetailProvider
|
||||
28
web/app/components/workflow/datasets-detail-store/store.ts
Normal file
28
web/app/components/workflow/datasets-detail-store/store.ts
Normal file
@@ -0,0 +1,28 @@
|
||||
import { useContext } from 'react'
|
||||
import { createStore, useStore } from 'zustand'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { DatasetsDetailContext } from './provider'
|
||||
import { fetchDatasets } from '@/service/datasets'
|
||||
|
||||
type DatasetsDetailStore = {
|
||||
datasetsDetail: DataSet[]
|
||||
updateDatasetsDetail: (allDatasetIds: string[]) => Promise<void>
|
||||
}
|
||||
|
||||
export const createDatasetsDetailStore = () => {
|
||||
return createStore<DatasetsDetailStore>(set => ({
|
||||
datasetsDetail: [],
|
||||
updateDatasetsDetail: async (allDatasetIds) => {
|
||||
const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: allDatasetIds } })
|
||||
set({ datasetsDetail: dataSetsWithDetail })
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
export const useDatasetsDetailStore = <T>(selector: (state: DatasetsDetailStore) => T): T => {
|
||||
const store = useContext(DatasetsDetailContext)
|
||||
if (!store)
|
||||
throw new Error('Missing DatasetsDetailContext.Provider in the tree')
|
||||
|
||||
return useStore(store, selector)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useStoreApi } from 'reactflow'
|
||||
import type {
|
||||
CommonNodeType,
|
||||
Edge,
|
||||
Node,
|
||||
} from '../types'
|
||||
@@ -27,6 +28,8 @@ import { useGetLanguage } from '@/context/i18n'
|
||||
import type { AgentNodeType } from '../nodes/agent/types'
|
||||
import { useStrategyProviders } from '@/service/use-strategy'
|
||||
import { canFindTool } from '@/utils'
|
||||
import { useDatasetsDetailStore } from '../datasets-detail-store/store'
|
||||
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
|
||||
|
||||
export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
const { t } = useTranslation()
|
||||
@@ -37,6 +40,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
const customTools = useStore(s => s.customTools)
|
||||
const workflowTools = useStore(s => s.workflowTools)
|
||||
const { data: strategyProviders } = useStrategyProviders()
|
||||
const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail)
|
||||
|
||||
const needWarningNodes = useMemo(() => {
|
||||
const list = []
|
||||
@@ -75,7 +79,15 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
}
|
||||
|
||||
if (node.type === CUSTOM_NODE) {
|
||||
const { errorMessage } = nodesExtraData[node.data.type].checkValid(node.data, t, moreDataForCheckValid)
|
||||
let checkData = node.data
|
||||
if (node.data.type === BlockEnum.KnowledgeRetrieval) {
|
||||
const _datasets = datasetsDetail.filter(dataset => (node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids.includes(dataset.id))
|
||||
checkData = {
|
||||
...node.data,
|
||||
_datasets,
|
||||
} as CommonNodeType<KnowledgeRetrievalNodeType>
|
||||
}
|
||||
const { errorMessage } = nodesExtraData[node.data.type].checkValid(checkData, t, moreDataForCheckValid)
|
||||
|
||||
if (errorMessage || !validNodes.find(n => n.id === node.id)) {
|
||||
list.push({
|
||||
@@ -109,7 +121,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
}
|
||||
|
||||
return list
|
||||
}, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders])
|
||||
}, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders, datasetsDetail])
|
||||
|
||||
return needWarningNodes
|
||||
}
|
||||
|
||||
@@ -99,6 +99,7 @@ import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
|
||||
import { fetchFileUploadConfig } from '@/service/common'
|
||||
import DatasetsDetailProvider from './datasets-detail-store/provider'
|
||||
|
||||
const nodeTypes = {
|
||||
[CUSTOM_NODE]: CustomNode,
|
||||
@@ -448,11 +449,13 @@ const WorkflowWrap = memo(() => {
|
||||
nodes={nodesData}
|
||||
edges={edgesData} >
|
||||
<FeaturesProvider features={initialFeatures}>
|
||||
<Workflow
|
||||
nodes={nodesData}
|
||||
edges={edgesData}
|
||||
viewport={data?.graph.viewport}
|
||||
/>
|
||||
<DatasetsDetailProvider nodes={nodesData}>
|
||||
<Workflow
|
||||
nodes={nodesData}
|
||||
edges={edgesData}
|
||||
viewport={data?.graph.viewport}
|
||||
/>
|
||||
</DatasetsDetailProvider>
|
||||
</FeaturesProvider>
|
||||
</WorkflowHistoryProvider>
|
||||
</ReactFlowProvider>
|
||||
|
||||
@@ -41,6 +41,7 @@ import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-s
|
||||
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use-available-var-list'
|
||||
import { useDatasetsDetailStore } from '../../datasets-detail-store/store'
|
||||
|
||||
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
const { nodesReadOnly: readOnly } = useNodesReadOnly()
|
||||
@@ -49,6 +50,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
|
||||
const startNodeId = startNode?.id
|
||||
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
|
||||
const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail)
|
||||
const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail)
|
||||
|
||||
const inputRef = useRef(inputs)
|
||||
|
||||
@@ -218,15 +221,12 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
(async () => {
|
||||
const inputs = inputRef.current
|
||||
const datasetIds = inputs.dataset_ids
|
||||
let _datasets = selectedDatasets
|
||||
if (datasetIds?.length > 0) {
|
||||
const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } as any })
|
||||
_datasets = dataSetsWithDetail
|
||||
setSelectedDatasets(dataSetsWithDetail)
|
||||
}
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.dataset_ids = datasetIds
|
||||
draft._datasets = _datasets
|
||||
})
|
||||
setInputs(newInputs)
|
||||
setSelectedDatasetsLoaded(true)
|
||||
@@ -256,7 +256,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
} = getSelectedDatasetsMode(newDatasets)
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.dataset_ids = newDatasets.map(d => d.id)
|
||||
draft._datasets = newDatasets
|
||||
|
||||
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
@@ -266,6 +265,15 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
})
|
||||
}
|
||||
})
|
||||
const allDatasetIds = datasetsDetail.map(d => d.id)
|
||||
const newAllDatasetIds = produce(allDatasetIds, (draft) => {
|
||||
const newDatasetIds = newDatasets.map(d => d.id)
|
||||
newDatasetIds.forEach((id) => {
|
||||
if (!draft.includes(id))
|
||||
draft.push(id)
|
||||
})
|
||||
})
|
||||
updateDatasetsDetail(newAllDatasetIds)
|
||||
setInputs(newInputs)
|
||||
setSelectedDatasets(newDatasets)
|
||||
|
||||
@@ -275,7 +283,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
|| allExternal
|
||||
)
|
||||
setRerankModelOpen(true)
|
||||
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider])
|
||||
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider, datasetsDetail, updateDatasetsDetail])
|
||||
|
||||
const filterVar = useCallback((varPayload: Var) => {
|
||||
return varPayload.type === VarType.string
|
||||
|
||||
Reference in New Issue
Block a user