Compare commits

...

3 Commits

5 changed files with 108 additions and 12 deletions

View File

@@ -0,0 +1,45 @@
import type { FC } from 'react'
import { createContext, useEffect, useRef } from 'react'
import { createDatasetsDetailStore } from './store'
import type { CommonNodeType, Node } from '../types'
import { BlockEnum } from '../types'
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
type DatasetsDetailStoreApi = ReturnType<typeof createDatasetsDetailStore>
type DatasetsDetailContextType = DatasetsDetailStoreApi | undefined
export const DatasetsDetailContext = createContext<DatasetsDetailContextType>(undefined)
type DatasetsDetailProviderProps = {
nodes: Node[]
children: React.ReactNode
}
const DatasetsDetailProvider: FC<DatasetsDetailProviderProps> = ({
nodes,
children,
}) => {
const storeRef = useRef<DatasetsDetailStoreApi>()
if (!storeRef.current)
storeRef.current = createDatasetsDetailStore()
useEffect(() => {
if (!storeRef.current) return
const knowledgeRetrievalNodes = nodes.filter(node => node.data.type === BlockEnum.KnowledgeRetrieval)
const allDatasetIds = knowledgeRetrievalNodes.reduce<string[]>((acc, node) => {
return Array.from(new Set([...acc, ...(node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids]))
}, [])
storeRef.current.getState().updateDatasetsDetail(allDatasetIds)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
return (
<DatasetsDetailContext.Provider value={storeRef.current!}>
{children}
</DatasetsDetailContext.Provider>
)
}
export default DatasetsDetailProvider

View File

@@ -0,0 +1,28 @@
import { useContext } from 'react'
import { createStore, useStore } from 'zustand'
import type { DataSet } from '@/models/datasets'
import { DatasetsDetailContext } from './provider'
import { fetchDatasets } from '@/service/datasets'
type DatasetsDetailStore = {
datasetsDetail: DataSet[]
updateDatasetsDetail: (allDatasetIds: string[]) => Promise<void>
}
export const createDatasetsDetailStore = () => {
return createStore<DatasetsDetailStore>(set => ({
datasetsDetail: [],
updateDatasetsDetail: async (allDatasetIds) => {
const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: allDatasetIds } })
set({ datasetsDetail: dataSetsWithDetail })
},
}))
}
export const useDatasetsDetailStore = <T>(selector: (state: DatasetsDetailStore) => T): T => {
const store = useContext(DatasetsDetailContext)
if (!store)
throw new Error('Missing DatasetsDetailContext.Provider in the tree')
return useStore(store, selector)
}

View File

@@ -5,6 +5,7 @@ import {
import { useTranslation } from 'react-i18next'
import { useStoreApi } from 'reactflow'
import type {
CommonNodeType,
Edge,
Node,
} from '../types'
@@ -27,6 +28,8 @@ import { useGetLanguage } from '@/context/i18n'
import type { AgentNodeType } from '../nodes/agent/types'
import { useStrategyProviders } from '@/service/use-strategy'
import { canFindTool } from '@/utils'
import { useDatasetsDetailStore } from '../datasets-detail-store/store'
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
export const useChecklist = (nodes: Node[], edges: Edge[]) => {
const { t } = useTranslation()
@@ -37,6 +40,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
const customTools = useStore(s => s.customTools)
const workflowTools = useStore(s => s.workflowTools)
const { data: strategyProviders } = useStrategyProviders()
const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail)
const needWarningNodes = useMemo(() => {
const list = []
@@ -75,7 +79,15 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
}
if (node.type === CUSTOM_NODE) {
const { errorMessage } = nodesExtraData[node.data.type].checkValid(node.data, t, moreDataForCheckValid)
let checkData = node.data
if (node.data.type === BlockEnum.KnowledgeRetrieval) {
const _datasets = datasetsDetail.filter(dataset => (node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids.includes(dataset.id))
checkData = {
...node.data,
_datasets,
} as CommonNodeType<KnowledgeRetrievalNodeType>
}
const { errorMessage } = nodesExtraData[node.data.type].checkValid(checkData, t, moreDataForCheckValid)
if (errorMessage || !validNodes.find(n => n.id === node.id)) {
list.push({
@@ -109,7 +121,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
}
return list
}, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders])
}, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders, datasetsDetail])
return needWarningNodes
}

View File

@@ -99,6 +99,7 @@ import { useEventEmitterContextContext } from '@/context/event-emitter'
import Confirm from '@/app/components/base/confirm'
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
import { fetchFileUploadConfig } from '@/service/common'
import DatasetsDetailProvider from './datasets-detail-store/provider'
const nodeTypes = {
[CUSTOM_NODE]: CustomNode,
@@ -448,11 +449,13 @@ const WorkflowWrap = memo(() => {
nodes={nodesData}
edges={edgesData} >
<FeaturesProvider features={initialFeatures}>
<Workflow
nodes={nodesData}
edges={edgesData}
viewport={data?.graph.viewport}
/>
<DatasetsDetailProvider nodes={nodesData}>
<Workflow
nodes={nodesData}
edges={edgesData}
viewport={data?.graph.viewport}
/>
</DatasetsDetailProvider>
</FeaturesProvider>
</WorkflowHistoryProvider>
</ReactFlowProvider>

View File

@@ -41,6 +41,7 @@ import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-s
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use-available-var-list'
import { useDatasetsDetailStore } from '../../datasets-detail-store/store'
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
const { nodesReadOnly: readOnly } = useNodesReadOnly()
@@ -49,6 +50,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
const startNodeId = startNode?.id
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail)
const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail)
const inputRef = useRef(inputs)
@@ -218,15 +221,12 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
(async () => {
const inputs = inputRef.current
const datasetIds = inputs.dataset_ids
let _datasets = selectedDatasets
if (datasetIds?.length > 0) {
const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } as any })
_datasets = dataSetsWithDetail
setSelectedDatasets(dataSetsWithDetail)
}
const newInputs = produce(inputs, (draft) => {
draft.dataset_ids = datasetIds
draft._datasets = _datasets
})
setInputs(newInputs)
setSelectedDatasetsLoaded(true)
@@ -256,7 +256,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
} = getSelectedDatasetsMode(newDatasets)
const newInputs = produce(inputs, (draft) => {
draft.dataset_ids = newDatasets.map(d => d.id)
draft._datasets = newDatasets
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
const multipleRetrievalConfig = draft.multiple_retrieval_config
@@ -266,6 +265,15 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
})
}
})
const allDatasetIds = datasetsDetail.map(d => d.id)
const newAllDatasetIds = produce(allDatasetIds, (draft) => {
const newDatasetIds = newDatasets.map(d => d.id)
newDatasetIds.forEach((id) => {
if (!draft.includes(id))
draft.push(id)
})
})
updateDatasetsDetail(newAllDatasetIds)
setInputs(newInputs)
setSelectedDatasets(newDatasets)
@@ -275,7 +283,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|| allExternal
)
setRerankModelOpen(true)
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider])
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider, datasetsDetail, updateDatasetsDetail])
const filterVar = useCallback((varPayload: Var) => {
return varPayload.type === VarType.string