model_train

Warning

The code reference is a work in progress and may contain inconsistencies.

socialnet.model_train(trajectory_files: Sequence[Path | str] | Path | str, output_parent: Path | str | None = None, session_name: str | None = None, model: Literal['fc', 'attention', 'interaction'] = 'attention', future_steps: int = 32, history_steps: int = 0, input_types: list[list[Literal['n', 'natt', 'gvars', 'h', 'hvars', 'hbott', 'pts']]] | None = None, variables: list[list[Literal['fV', 'fvx', 'fvy', 'fv', 'fA', 'fax', 'fay', 'fa', 'nbL', 'nbS', 'nbx', 'nby', 'nbE', 'nbV', 'nbvx', 'nbvy', 'nbv', 'nbA', 'nbax', 'nbay', 'nban', 'nbat', 'nba']]] | None = None, history_variables: list[list[Literal['fst', 'nbst', 'fdt', 'nbdt']]] | None = None, num_neighbours: int = 4, blind: list[Literal['nbv', 'nba', 'nbat', 'nban', 'fv', 'fa', 'fat', 'fan']] | None = None, sigma: float = 1.0, remove_outer: float = 0.8, shuffle: Literal['trajectories', 'social_context', ''] = '', loader_fractions: list[float | None] | None = None, target: Literal['turn', 'location', 'acceleration', 'acceleration_gaussian'] = 'turn', **kwargs)

Train a SocialNet model on a set of trajectory files.
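A minimal usage sketch, assuming the package is importable as `socialnet` and exposes `model_train` with the signature above. The wrapper name `train_turn_model`, the `runs` output folder and the file names are hypothetical placeholders. The trainer is passed in as an argument so the call can be exercised without the package installed; when it is omitted, the sketch imports `socialnet.model_train`.

```python
from __future__ import annotations

from pathlib import Path
from typing import Any, Callable


def train_turn_model(
    trajectory_files: list[Path | str],
    session_name: str,
    model_train: Callable[..., dict] | None = None,
) -> dict[str, Any]:
    """Train an attention model to predict the future turning side (hypothetical wrapper)."""
    if model_train is None:
        # Assumption: the package is installed and importable as `socialnet`.
        import socialnet

        model_train = socialnet.model_train

    return model_train(
        trajectory_files,
        output_parent="runs",  # hypothetical output folder
        session_name=session_name,
        # The next five values repeat the documented defaults, spelled out for clarity.
        model="attention",
        target="turn",
        future_steps=32,
        history_steps=0,
        num_neighbours=4,
        # Split fractions taken from the example given for loader_fractions.
        loader_fractions=[0.5, 0.01, 0.01],
    )
```

For example, `train_turn_model(["trajectories_a.csv"], "demo")` would return whatever dictionary `model_train` produces for that session.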

Parameters:
  • trajectory_files (Sequence[Path | str] | Path | str) – Trajectory files to train the model on. If a single file is passed, it is converted to a list.

  • output_parent (Path | str | None, optional) – Parent folder where the output will be saved. If None, defaults to the current working directory.

  • session_name (str | None, optional) – Name of the session. If None, defaults to the current date and time.

  • model (Literal["fc", "attention", "interaction"], optional) – Model type: “fc” for fully connected, “attention” for attention-based, or “interaction” for interaction-based, by default “attention”.

  • future_steps (int, optional) – Number of future frames the model predicts, by default 32.

  • history_steps (int, optional) – Number of past frames the model uses when making predictions, by default 0.

  • input_types (list[list[INPUT_TYPE]], optional) – List of lists of input types for the model. Each sublist contains one or more input types; valid values are “n”, “natt”, “gvars”, “h”, “hvars”, “hbott”, and “pts”.

  • variables (list[list[VARIABLES]], optional) –

    List of lists of variable names, one sublist per “gvars” input type. The variables can include:

    • fV: focal velocity vector.

    • fvx, fvy: focal velocity x or y components.

    • fv: focal speed.

    • fA: focal acceleration vector.

    • fax, fay: focal acceleration x or y components.

    • fa: norm of the focal acceleration vector.

    • nbL or nbS: neighbour position vector.

    • nbx, nby: neighbour position x or y components.

    • nbE: neighbour orientation vector.

    • nbV: neighbour velocity vector.

    • nbvx, nbvy: neighbour velocity x or y components.

    • nbv: neighbour speed.

    • nbA: neighbour acceleration vector.

    • nbax, nbay: neighbour acceleration x or y components.

    • nban, nbat: neighbour acceleration (normal or tangential).

    • nba: norm of the neighbour acceleration vector.

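  • history_variables (list[list[Literal["fst", "nbst", "fdt", "nbdt"]]], optional) – List of lists of variable names, one sublist per “hvars” input type. As in variables, the prefix “f” refers to the focal individual and “nb” to its neighbours; the suffix “st” denotes straightness and “dt” distance travelled.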

  • num_neighbours (int, optional) – Number of neighbours to consider in the model, by default 4.

  • blind (list[Literal["nbv", "nba", "nbat", "nban", "fv", "fa", "fat", "fan"]], optional) – List of variables to blind the model to, e.g., “nbv” or “nba”.

  • sigma (float, optional) – Smoothing factor applied to the data, by default 1.0.

  • remove_outer (float, optional) – Fraction of the radius to remove from the outer region, by default 0.8.

  • shuffle (Literal["trajectories", "social_context", ""], optional) – Shuffle mode: “trajectories”, “social_context”, or “” (no shuffling, the default).

  • loader_fractions (list[float | None], optional) – Fractions for train, validation, and test splits (e.g., [0.5, 0.01, 0.01]).

  • target (Literal["turn", "location", "acceleration", "acceleration_gaussian"], optional) – What to predict: “turn” (future turning side, the default), “location” (future location), “acceleration”, or “acceleration_gaussian”.

  • **kwargs – Additional keyword arguments for advanced configuration.
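Most of the options above are constrained by the `Literal` types in the signature, so a typo such as `model="attn"` is easy to catch before a long training run starts. The sketch below is a standalone pre-flight check, not part of `socialnet`; `check_options` and `normalize_files` are hypothetical helper names. The allowed values are copied from the signature, and `normalize_files` mirrors the documented single-file-to-list conversion. The `loader_fractions` checks (three entries, each non-`None` value in [0, 1], summing to at most 1) are assumptions, not documented behaviour.

```python
from __future__ import annotations

from pathlib import Path
from typing import Iterable, Sequence

# Allowed values copied from the model_train signature.
MODELS = {"fc", "attention", "interaction"}
TARGETS = {"turn", "location", "acceleration", "acceleration_gaussian"}
SHUFFLE_MODES = {"trajectories", "social_context", ""}
INPUT_TYPES = {"n", "natt", "gvars", "h", "hvars", "hbott", "pts"}
VARIABLES = {
    "fV", "fvx", "fvy", "fv", "fA", "fax", "fay", "fa",
    "nbL", "nbS", "nbx", "nby", "nbE", "nbV", "nbvx", "nbvy", "nbv",
    "nbA", "nbax", "nbay", "nban", "nbat", "nba",
}
HISTORY_VARIABLES = {"fst", "nbst", "fdt", "nbdt"}
BLIND = {"nbv", "nba", "nbat", "nban", "fv", "fa", "fat", "fan"}


def normalize_files(trajectory_files: Sequence[Path | str] | Path | str) -> list[Path]:
    """Wrap a single path in a list and convert every entry to a Path."""
    if isinstance(trajectory_files, (str, Path)):
        trajectory_files = [trajectory_files]
    return [Path(f) for f in trajectory_files]


def _check_items(name: str, items: Iterable[str], allowed: set[str], errors: list[str]) -> None:
    """Record any values in `items` that are not in `allowed`."""
    bad = sorted(set(items) - allowed)
    if bad:
        errors.append(f"{name}: unsupported value(s) {bad}")


def check_options(
    model: str = "attention",
    target: str = "turn",
    shuffle: str = "",
    input_types: list[list[str]] | None = None,
    variables: list[list[str]] | None = None,
    history_variables: list[list[str]] | None = None,
    blind: list[str] | None = None,
    loader_fractions: list[float | None] | None = None,
) -> list[str]:
    """Return a list of problems; an empty list means the options look valid."""
    errors: list[str] = []
    _check_items("model", [model], MODELS, errors)
    _check_items("target", [target], TARGETS, errors)
    _check_items("shuffle", [shuffle], SHUFFLE_MODES, errors)
    for name, groups, allowed in (
        ("input_types", input_types, INPUT_TYPES),
        ("variables", variables, VARIABLES),
        ("history_variables", history_variables, HISTORY_VARIABLES),
    ):
        for group in groups or []:
            _check_items(name, group, allowed, errors)
    _check_items("blind", blind or [], BLIND, errors)
    if loader_fractions is not None:
        # Assumption: one entry each for the train, validation and test splits.
        if len(loader_fractions) != 3:
            errors.append("loader_fractions: expected 3 entries (train, validation, test)")
        given = [f for f in loader_fractions if f is not None]
        if any(not 0.0 <= f <= 1.0 for f in given) or sum(given) > 1.0:
            errors.append("loader_fractions: values must lie in [0, 1] and sum to at most 1")
    return errors
```

For example, `check_options(model="attn")` returns `["model: unsupported value(s) ['attn']"]`, while the documented defaults produce an empty list.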

Returns:

Dictionary with training results and statistics.

Return type:

dict
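Because `model_train` returns a plain dictionary, results from several configurations can be collected side by side. The sketch below runs one session per model type listed in the signature and stores each returned dictionary unchanged, since the keys inside it are not documented here. `compare_models` and the `compare_<model>` session-naming scheme are hypothetical; in practice the trainer would be `socialnet.model_train`, assuming the package is installed.

```python
from __future__ import annotations

from pathlib import Path
from typing import Any, Callable

MODEL_TYPES = ("fc", "attention", "interaction")  # values allowed by the `model` parameter


def compare_models(
    trajectory_files: list[Path | str],
    model_train: Callable[..., dict],
    **common: Any,
) -> dict[str, dict]:
    """Run model_train once per model type and collect the returned dictionaries."""
    results: dict[str, dict] = {}
    for model in MODEL_TYPES:
        results[model] = model_train(
            trajectory_files,
            model=model,
            session_name=f"compare_{model}",  # hypothetical naming scheme
            **common,  # shared options such as target or future_steps
        )
    return results
```

Shared options go through `**common`, for example `compare_models(files, socialnet.model_train, target="turn")`; passing `model` or `session_name` there would raise a `TypeError`, because the sketch sets both itself.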